Created by: tangbinh
Summary of Changes
We add an option to convert weights into a new dtype
while resharding FSDP checkpoints. This helps reduce checkpoint sizes and avoids issues under RAM constraints when we attempt to load checkpoints. For example, model weights might be saved in full precision but one only needs half precision at inference time.
We also rename some options (e.g. input-glob-pattern
→ input
and output-shard-name
→ output
) for succinctness.
Test Plan
- Reshard and convert the OPT-125M checkpoint into various dtypes and make sure we can do inference correctly in the new dtypes:
seq 0 1 | parallel --line-buffer 'python metaseq/scripts/reshard_fsdp.py --input "/data/checkpoints/opt-125m/raw/checkpoint_last-model_part-{}-shard*.pt" --output "/data/checkpoints/opt-125m/reshard-no-os/reshard-model_part-{}.pt" --skip-optimizer-state True --unflatten-weights True --output-dtype fp16'