Created by: thomasw21
## Patch Description
Related to: #60

`convert_to_singleton.py` doesn't seem to recursively unwrap fully sharded modules, leaving the following parameters in `restored.pt` (notice the `flat_param` substring):
```
{'decoder.embed_positions.weight': torch.Size([2050, 2048]), 'decoder.embed_tokens.weight': torch.Size([50272, 2048]), 'decoder.layer_norm.bias': torch.Size([2048]), 'decoder.layer_norm.weight': torch.Size([2048]), 'decoder.layers.0.flat_param_0': torch.Size([25185280]), 'decoder.layers.1.flat_param_0': torch.Size([25185280]), 'decoder.layers.10.flat_param_0': torch.Size([25185280]), 'decoder.layers.11.flat_param_0': torch.Size([25185280]), 'decoder.layers.12.flat_param_0': torch.Size([25185280]), 'decoder.layers.13.flat_param_0': torch.Size([25185280]), 'decoder.layers.14.flat_param_0': torch.Size([25185280]), 'decoder.layers.15.flat_param_0': torch.Size([25185280]), 'decoder.layers.16.flat_param_0': torch.Size([25185280]), 'decoder.layers.17.flat_param_0': torch.Size([25185280]), 'decoder.layers.18.flat_param_0': torch.Size([25185280]), 'decoder.layers.19.flat_param_0': torch.Size([25185280]), 'decoder.layers.2.flat_param_0': torch.Size([25185280]), 'decoder.layers.20.flat_param_0': torch.Size([25185280]), 'decoder.layers.21.flat_param_0': torch.Size([25185280]), 'decoder.layers.22.flat_param_0': torch.Size([25185280]), 'decoder.layers.23.flat_param_0': torch.Size([25185280]), 'decoder.layers.3.flat_param_0': torch.Size([25185280]), 'decoder.layers.4.flat_param_0': torch.Size([25185280]), 'decoder.layers.5.flat_param_0': torch.Size([25185280]), 'decoder.layers.6.flat_param_0': torch.Size([25185280]), 'decoder.layers.7.flat_param_0': torch.Size([25185280]), 'decoder.layers.8.flat_param_0': torch.Size([25185280]), 'decoder.layers.9.flat_param_0': torch.Size([25185280]), 'decoder.version': torch.Size([1])}
```
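For reference, a minimal sketch of the recursive-unwrapping idea (not the exact change in this patch; the wrapper class passed in and its `.module` attribute are assumptions based on how FSDP-style wrappers conventionally expose their inner module):

```python
import torch.nn as nn


def unwrap_fsdp_recursively(module: nn.Module, wrapper_types: tuple) -> nn.Module:
    """Recursively replace FSDP-style wrapper modules with the module they wrap,
    so that state_dict() exposes real parameter names instead of flat_param_* keys.

    `wrapper_types` is assumed to be something like (FullyShardedDataParallel,);
    reading the wrapped module off `.module` is an assumption, not metaseq's API.
    """
    for name, child in module.named_children():
        # Unwrap the whole subtree first, then strip any wrapper(s) around it.
        unwrapped = unwrap_fsdp_recursively(child, wrapper_types)
        while isinstance(unwrapped, wrapper_types):
            unwrapped = unwrapped.module
        setattr(module, name, unwrapped)
    return module
```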
## Testing steps
Tested on a 1B3 checkpoint; the keys in `restored.pt` now correspond to their unwrapped versions. Haven't tested logits/generation, as the model should already be loaded from the checkpoint beforehand.
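As a quick sanity check (path and checkpoint layout assumed, i.e. `restored.pt` holding the parameter dict directly), one can verify that the converted checkpoint no longer contains flattened parameters:

```python
import torch

# Assumed output path of convert_to_singleton.py.
state_dict = torch.load("restored.pt", map_location="cpu")

# After the fix, no key should contain the "flat_param" substring.
flat_keys = [k for k in state_dict if "flat_param" in k]
assert not flat_keys, f"still flattened: {flat_keys}"
print({k: tuple(v.shape) for k, v in state_dict.items()})
```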
cc @stephenroller