Created by: tangbinh
Summary of Changes
The existing script for resharding model parallel parts (i.e. metaseq/scripts/reshard_model_parallel.py) loads all checkpoint parts at once and can run out of memory (OOM) under RAM constraints, especially for very large models. Here, we rewrite the script to optimize memory usage by first allocating an unsharded model state dict and then iteratively merging model parallel parts into it.
Previously, peak memory usage was close to 2x the model size because we needed to hold both the input and output state dicts. With the iterative process, it is theoretically closer to 1x the model size.
The new script produces the same output as metaseq/scripts/reshard_model_parallel.py. We delete the old script to avoid duplication; it remains accessible in the internal repo (see this script).
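The iterative merge described above can be sketched as follows. This is a minimal illustration, not the code in this PR: it uses NumPy arrays in place of torch tensors, and `merge_parts_iteratively`, `load_part`, and `partition_dim` are hypothetical names. It also assumes every part holds an equal-size slice of each partitioned tensor.

```python
import numpy as np


def merge_parts_iteratively(load_part, num_parts, partition_dim):
    """Merge model parallel parts into one unsharded state dict, one part at a time.

    load_part(rank) returns the state dict of a single part (hypothetical loader).
    partition_dim maps a key to the axis it is split along; None or a missing
    key means the tensor is replicated across parts (assumption of this sketch).
    """
    merged = {}
    for rank in range(num_parts):
        part = load_part(rank)  # only one input part is resident at a time
        for name, shard in part.items():
            dim = partition_dim.get(name)
            if dim is None:
                # Replicated tensor: keep the copy from the first part seen.
                merged.setdefault(name, shard.copy())
                continue
            if name not in merged:
                # Allocate the full-size output once, when the key first appears.
                full_shape = list(shard.shape)
                full_shape[dim] *= num_parts
                merged[name] = np.empty(full_shape, dtype=shard.dtype)
            # Copy this part's slice into its position in the unsharded tensor.
            index = [slice(None)] * shard.ndim
            size = shard.shape[dim]
            index[dim] = slice(rank * size, (rank + 1) * size)
            merged[name][tuple(index)] = shard
        del part  # drop the input part before loading the next one
    return merged


# Toy example: a 4x6 weight split row-wise into 2 parts, plus a replicated bias.
full = {
    "fc.weight": np.arange(24, dtype=np.float32).reshape(4, 6),
    "norm.bias": np.ones(6, dtype=np.float32),
}
dims = {"fc.weight": 0, "norm.bias": None}


def load_part(rank):
    return {
        "fc.weight": full["fc.weight"][2 * rank : 2 * rank + 2],
        "norm.bias": full["norm.bias"],
    }


merged = merge_parts_iteratively(load_part, num_parts=2, partition_dim=dims)
```

Under these assumptions, the resident set is roughly the unsharded output plus one input part at any moment, which is why the peak approaches 1x the model size instead of 2x.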
Test Plan
- Run the script with an OPT 2.7B checkpoint to reshard 4 MP parts into 8 MP parts and make sure the resulting checkpoint performs reasonably:
seq 0 3 | parallel --line-buffer 'python metaseq/scripts/reshard_fsdp.py --input "/data/checkpoints/opt-2.7b/raw/checkpoint_last-model_part-{}-shard*.pt" --output "/data/checkpoints/opt-2.7b/reshard-no-os/reshard-model_part-{}.pt" --skip-optimizer-state True --unflatten-weights True --output-dtype fp16'
python -m metaseq.scripts.reshard_mp --input "/data/checkpoints/opt-2.7b/reshard_no_os/reshard-model_part-*.pt" --output "/data/checkpoints/opt-2.7b/reshard_no_os_mp8/reshard-model_part-{i}.pt" --num-output-parts 8
python metaseq/scripts/interactive.py --merges-filename /data/checkpoints/gpt2-merges.txt --vocab-filename /data/checkpoints/gpt2-vocab.json --path /data/checkpoints/opt-2.7b/reshard_no_os_mp8/reshard.pt --model-parallel-size 8 --distributed-world-size 8 --beam 3 --max-source-positions 4 --max-target-positions 128
> Prompt: What is the meaning of life?
Output: To be happy.
- We compare performance with metaseq/scripts/reshard_model_parallel.py while resharding an OPT-175B checkpoint from 8 MP parts into 16 MP parts. The old script takes 849.40 seconds and results in a peak RSS delta of 668,301 MB, while the new script takes 891.65 seconds and has an RSS delta of 458,185 MB (a 31% reduction in RAM usage; equivalently, the old script used about 46% more RAM):
python metaseq/scripts/reshard_model_parallel.py --pth_prefix /data/checkpoints/opt-175b/reshard_no_os_unflat/reshard.pt --new-model-parts 16 --save-prefix /data/checkpoints/opt-175b/reshard_no_os_unflat_mp16_ref/reshard.pt
python -m metaseq.scripts.reshard_mp --input "/data/checkpoints/opt-175b/reshard_no_os_unflat/reshard-model_part-*.pt" --output "/data/checkpoints/opt-175b/reshard_no_os_unflat_mp16/reshard-model_part-{i}.pt" --num-output-parts 16
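The relative changes implied by the OPT-175B measurements above can be checked with a few lines of arithmetic. The variable names are illustrative; the numbers are the ones reported in this test plan.

```python
# Measurements reported for resharding OPT-175B from 8 to 16 MP parts.
old_seconds, new_seconds = 849.40, 891.65
old_rss_mb, new_rss_mb = 668_301, 458_185

# Reduction relative to the old script's RSS delta.
rss_reduction = (old_rss_mb - new_rss_mb) / old_rss_mb
# How much more RAM the old script used, relative to the new script.
old_over_new = old_rss_mb / new_rss_mb - 1
# Extra wall-clock time of the new script, relative to the old one.
slowdown = new_seconds / old_seconds - 1

print(f"RSS reduction vs. old script: {rss_reduction:.1%}")
print(f"Old script RSS over new:      {old_over_new:.1%}")
print(f"Runtime increase:             {slowdown:.1%}")
```

The first two figures describe the same gap from different baselines, so it helps to state which script the percentage is measured against.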