Created by: tangbinh
Summary
We add a new script to reshard raw FSDP checkpoints as part of our efforts to consolidate the checkpoint resharding logic. This script is a bit more general than some of the existing ones:
- Compared to `reshard_mp.py`, it allows us to optionally unflatten model weights (see the sketch after this list) and remain compatible with the generator interface when `ddp-backend` is set to `pytorch_ddp`.
- Compared to the `consolidate_shard_weights` and `build_unflat_state_dict` functions from FSDP (the former is used in `stitch_fsdp_ckpt.py`), it supports both unsharding and resharding of model weights and optimizer states.
- Compared to `checkpoint_utils.py`, which is used in `convert_to_singleton.py`, it doesn't require instantiating FSDP instances and avoids the various requirements that come with that (DDP, vocab files, configs, etc.). We also decouple the filename handling to make it a bit more flexible.
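For intuition, the unflattening step boils down to concatenating each FSDP flat parameter across ranks, dropping the trailing padding, and splitting the buffer back into named tensors using the recorded shapes. The sketch below only illustrates that idea; the `names`, `shapes`, and `padding` fields are hypothetical placeholders, not the actual `shard_metadata` schema the script reads.

```python
import torch

def unflatten_flat_param(rank_slices, meta):
    """Illustrative sketch: rebuild named parameters from one FSDP flat parameter.

    rank_slices: each rank's 1-D slice of the flat parameter, in rank order.
    meta: hypothetical metadata holding parameter "names", "shapes", and the
        number of trailing "padding" elements appended so the flat parameter
        divides evenly across ranks.
    """
    flat = torch.cat(rank_slices)
    if meta["padding"] > 0:
        flat = flat[: -meta["padding"]]
    numels = [torch.Size(shape).numel() for shape in meta["shapes"]]
    pieces = torch.split(flat, numels)
    return {
        name: piece.view(shape)
        for name, shape, piece in zip(meta["names"], meta["shapes"], pieces)
    }

# Toy example: a 2x3 weight and a bias of size 3, padded by one element so the
# 10-element flat parameter splits evenly across two ranks.
meta = {"names": ["w", "b"], "shapes": [(2, 3), (3,)], "padding": 1}
flat = torch.cat([torch.arange(9, dtype=torch.float32), torch.zeros(1)])
params = unflatten_flat_param([flat[:5], flat[5:]], meta)
assert params["w"].shape == (2, 3) and params["b"].shape == (3,)
```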
Note that this script doesn't include the logic for model parallel resharding. We should probably have a separate script for it, which can be used together with this one.
Testing
- Run the script to merge the sharded checkpoints of the 2.7B-parameter model into one shard for each model parallel part and load the resharded checkpoints with the interactive CLI (a quick inspection snippet follows the commands):
```bash
for j in {0..3}; do
  python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/data/gpt-z/models/gptz/2.7B/raw/checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-2.7b/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True;
done
```

```
python -m metaseq.cli.interactive_cli
> what is the meaning of life?
To be happy.
```
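As an extra sanity check for this step, a resharded part can also be inspected directly from a Python shell before launching the CLI. This snippet is illustrative rather than part of the PR; it only assumes the standard top-level `model` entry (the same field the round-trip check below compares) and reuses the output path from the command above.

```python
import torch

# Inspect one unflattened, single-shard model-parallel part.
ckpt = torch.load(
    "/shared/home/binhtang/checkpoints/opt-2.7b/reshard-model_part-0.pt",
    map_location=torch.device("cpu"),
)
# With --unflatten-weights True we expect regular named parameters here
# rather than FSDP flat_param buffers.
print(len(ckpt["model"]))
print(list(ckpt["model"])[:5])
```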
- Run the script to reshard the 6.7B-parameter model checkpoint for each model parallel part from 256 shards to 1 shard and from 1 shard back to 256 shards. The sharded checkpoints we get back are almost identical to the original ones, except for some rank-specific data that are lost during the first conversion because only rank 0's copies are kept (e.g. `optimizer_history`, `extra_state`, `cfg.distributed_training.distributed_rank`).
```bash
for j in {0..1}; do
  python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/data/gpt-z/models/gptz/6.7B/raw/checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-6.7b/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state False --unflatten-weights False;
done
```

```bash
for j in {0..1}; do
  python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/shared/home/binhtang/checkpoints/opt-6.7b/reshard-model_part-$j.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-6.7b-reshard/checkpoint_last-model_part-$j-shard{i}.pt" \
    --num-output-shards 256 --skip-optimizer-state False --unflatten-weights False;
done
```
```python
import torch

for i in range(256):
    before = torch.load(
        f"/data/gpt-z/models/gptz/6.7B/raw/checkpoint_last-model_part-0-shard{i}.pt",
        map_location=torch.device("cpu"),
    )
    after = torch.load(
        f"/shared/home/binhtang/checkpoints/opt-6.7b-reshard/checkpoint_last-model_part-0-shard{i}.pt",
        map_location=torch.device("cpu"),
    )
    # Model weights and shard metadata should round-trip exactly.
    assert all(torch.allclose(before["model"][k], after["model"][k]) for k in before["model"].keys())
    assert before["shard_metadata"] == after["shard_metadata"]
    # Adam moments should round-trip too; wrap the check in all() so the
    # generator expression is actually evaluated.
    assert all(
        torch.allclose(x[key], y[key])
        for x, y in zip(
            before["last_optimizer_state"]["state"].values(),
            after["last_optimizer_state"]["state"].values(),
        )
        for key in ("exp_avg", "exp_avg_sq")
    )
```