How to load sharded checkpoints?
Created by: patrickvonplaten
❓ Questions and Help
- After setting up the libraries as described in https://github.com/facebookresearch/metaseq/blob/main/docs/setup.md, the 350m checkpoint can be loaded directly since it is not sharded. First download it:

```bash
wget https://dl.fbaipublicfiles.com/opt/v1_20220502/350m/reshard.pt -P ./
```
- Next, we need to comment out one line in the Megatron-LM library that is only relevant for training (it initializes different random seeds across pipeline-parallel ranks). Comment out this line in your local clone of Megatron-LM: https://github.com/ngoyal2707/Megatron-LM/blob/ae0b844c1f6725c3433a95e42cac760b3885170b/megatron/initialize.py#L65
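  If you prefer to script that edit, a one-liner along these lines should do it; this is only a sketch and assumes GNU sed and that your Megatron-LM clone lives at `./Megatron-LM` (both assumptions, adjust to your setup):

  ```bash
  # Prepend "# " to line 65 of megatron/initialize.py, i.e. comment it out.
  # Assumes GNU sed and a local clone at ./Megatron-LM.
  sed -i '65s/^/# /' ./Megatron-LM/megatron/initialize.py
  ```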
- Now we write the following Python script to a `run_model.py` file:
```python
import os

import torch
from megatron.initialize import initialize_megatron
from metaseq import checkpoint_utils
from transformers import GPT2Tokenizer

path = "./"

# Model arguments taken from https://arxiv.org/pdf/2205.01068.pdf, Table 1 (350m config).
initialize_megatron(args_defaults={
    "micro_batch_size": 1,
    "num_layers": 24,
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "max_position_embeddings": 2048,
    "encoder_seq_length": 2048,
})

# Save the GPT-2 tokenizer files (vocab.json, merges.txt) locally so metaseq can find them.
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)

# Load the (unsharded) 350m checkpoint.
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "reshard.pt")],
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    },
)

model = checkpoint[0][0].eval()
```
- We can then load the checkpoint by running:

```bash
torchrun run_model.py --pipeline-model-parallel-size 1 --tensor-model-parallel-size 1
```
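As a quick sanity check that the loaded model produces sensible output, something along these lines can be appended to the end of `run_model.py`. This is only a sketch and assumes the metaseq model follows the usual fairseq-style decoder interface (a `LongTensor` of token ids in, a tuple whose first element is the logits tensor out):

```python
# Hedged sanity-check sketch, appended after the script above.
# Assumption: model(input_ids) returns a tuple whose first element is the logits.
prompt = "Hello, my name is"
input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)

with torch.no_grad():
    logits = model(input_ids)[0]

next_token_id = int(logits[0, -1].argmax())
print(prompt + tokenizer.decode([next_token_id]))
```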
Problem: This only works for the 350m checkpoint! For the other checkpoints, which are sharded, this approach fails.
E.g. when replacing `[os.path.join(path, "reshard.pt")]` with `[os.path.join(path, "reshard-model_part-0.pt"), os.path.join(path, "reshard-model_part-1.pt")]` (part 0 and part 1 of the 125M model), we get an error because the weights are all flattened into 1D arrays.
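To see what we mean, one can inspect a shard directly. A rough sketch (assuming the shard is a plain torch pickle; the traversal below only descends into dicts and prints tensor shapes):

```python
import torch

# Peek inside one shard of the 125M checkpoint to show that the weights
# are stored as flattened 1-D tensors.
state = torch.load("./reshard-model_part-0.pt", map_location="cpu")

def print_tensor_shapes(obj, prefix=""):
    if isinstance(obj, torch.Tensor):
        print(prefix, tuple(obj.shape))
    elif isinstance(obj, dict):
        for key, value in obj.items():
            print_tensor_shapes(value, f"{prefix}{key}.")

print_tensor_shapes(state)
```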
Using https://github.com/facebookresearch/metaseq/pull/29 sadly also doesn't help, since the checkpoints don't seem to be in the *shard* format required here: https://github.com/facebookresearch/metaseq/blob/48b9b6c083237f9b95c2eb67afc10005e10d67ee/metaseq/distributed/stitch_fsdp_ckpt.py#L45

The parameter flattening seems to come from FairScale, and we found functionality to unflatten the weights here: https://github.com/facebookresearch/fairscale/blob/51b53ddb6c3aa77426c7d5cc0b543b79628053c4/fairscale/nn/misc/flatten_params_wrapper.py#L358, but we haven't managed to figure out how to make it work for these checkpoints.
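For reference, the unflattening operation itself is simple in plain PyTorch once the original parameter names and shapes are known; the sketch below only illustrates that idea with made-up metadata (the `param_shapes` list is hypothetical, and recovering it from the FairScale/metaseq checkpoints is exactly the part we are missing):

```python
import torch

def unflatten_params(flat_param, param_shapes):
    """Split one flat 1-D tensor back into named, reshaped parameters.

    flat_param   -- 1-D tensor holding all parameters concatenated
    param_shapes -- list of (name, shape) pairs in the original flattening order
                    (hypothetical metadata, not read from the OPT checkpoints here)
    """
    numels = [int(torch.Size(shape).numel()) for _, shape in param_shapes]
    chunks = torch.split(flat_param, numels)
    return {
        name: chunk.view(shape)
        for (name, shape), chunk in zip(param_shapes, chunks)
    }

# Toy usage with made-up shapes:
shapes = [("embed.weight", (4, 3)), ("proj.bias", (3,))]
flat = torch.arange(15, dtype=torch.float32)
restored = unflatten_params(flat, shapes)
print({name: tuple(t.shape) for name, t in restored.items()})
# {'embed.weight': (4, 3), 'proj.bias': (3,)}
```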
@stephenroller @suchenzang @zhiqwang - any pointers on how we could load the 125M model (and the others) into a metaseq model instance?