Created by: punitkoura
Patch Description
- Support overriding sequence parallelism in the API
- Support loading FSDP-sharded models through the API
Testing steps
- Take a sequence-parallel checkpoint and create a constants module file (custom_constants_module.py) with the contents below:
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
MAX_SEQ_LEN = 2048
BATCH_SIZE = 2048 # silly high bc we dynamically batch by MAX_BATCH_TOKENS
MAX_BATCH_TOKENS = 3072
DEFAULT_PORT = 6010
MODEL_PARALLEL = <REPLACE_MODEL_PARALLEL_SIZE_HERE>
TOTAL_WORLD_SIZE = <REPLACE_WORLD_SIZE_HERE>
MAX_BEAM = 16
CHECKPOINT_FOLDER = <INSERT_MODEL_CHECKPOINT_FOLDER_HERE>
# tokenizer files
HF_TOKENIZER = <INSERT_TOKENIZER_FILE_HERE>
MODEL_FILE = os.path.join(CHECKPOINT_FOLDER, "reshard.pt")
LAUNCH_ARGS = [
    f"--model-parallel-size {MODEL_PARALLEL}",
    f"--distributed-world-size {TOTAL_WORLD_SIZE}",
    "--ddp-backend fully_sharded",
    "--task language_modeling",
    "--bpe hf_byte_bpe",
    f"--hf-tokenizer {HF_TOKENIZER}",
    f"--path {MODEL_FILE}",
    "--beam 1 --nbest 1",
    "--distributed-port 13000",
    "--checkpoint-shard-count 1",
    "--use-sharded-state",
    f"--batch-size {BATCH_SIZE}",
    f"--buffer-size {BATCH_SIZE * MAX_SEQ_LEN}",
    f"--max-tokens {BATCH_SIZE * MAX_SEQ_LEN}",
    "/tmp",  # required "data" argument.
]
# Optional arg overrides which influence model loading during inference
INFERENCE_ARG_OVERRIDES = {"sequence_parallel": False}
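INFERENCE_ARG_OVERRIDES lets inference-time settings win over whatever was baked into the checkpoint at training time. The sketch below shows the general pattern such a dict of overrides follows when merged onto loaded model arguments; it is illustrative only, and apply_overrides plus the Namespace fields are hypothetical, not metaseq's actual loading code.
# Illustrative sketch only: how a dict like INFERENCE_ARG_OVERRIDES is
# typically merged onto checkpoint args before the model is built. The
# function and field names here are hypothetical, not metaseq internals.
from argparse import Namespace

def apply_overrides(model_args: Namespace, overrides: dict) -> Namespace:
    # Each override replaces (or adds) the corresponding attribute, so e.g.
    # sequence_parallel=False takes effect at inference even though the
    # checkpoint was trained with sequence parallelism enabled.
    for key, value in overrides.items():
        setattr(model_args, key, value)
    return model_args

# Example: disable sequence parallelism for inference.
args = Namespace(sequence_parallel=True, model_parallel_size=8)
args = apply_overrides(args, {"sequence_parallel": False})
assert args.sequence_parallel is False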
- Export PYTHONPATH (pointing at the directory containing custom_constants_module.py) and METASEQ_SERVICE_CONSTANTS_MODULE so that interactive_hosted.py can import the constants module:
export PYTHONPATH=$PYTHONPATH:/path/to/custom_constants_module METASEQ_SERVICE_CONSTANTS_MODULE=custom_constants_module
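As an optional pre-flight check (assuming, as the step above implies, that interactive_hosted.py imports the module named by METASEQ_SERVICE_CONSTANTS_MODULE), you can confirm the module resolves in the same environment before launching the service:
# Import the constants module the same way the service will, and print the
# key settings defined above. Failures here (ImportError, missing
# attributes) are easier to debug than at service start-up.
import importlib
import os

module_name = os.environ["METASEQ_SERVICE_CONSTANTS_MODULE"]
constants = importlib.import_module(module_name)
for attr in ("MODEL_PARALLEL", "TOTAL_WORLD_SIZE", "CHECKPOINT_FOLDER",
             "MODEL_FILE", "LAUNCH_ARGS", "INFERENCE_ARG_OVERRIDES"):
    print(attr, "=", getattr(constants, attr))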
- Run interactive_hosted.py
srun --exclusive -N <NUMBER_OF_NODES> --gpus-per-node <NUMBER_OF_GPUS_PER_NODE> --tasks 1 -c 96 --partition <PUT_PARTITION_NAME_HERE> --time "1-00:00:00" --qos high --pty python metaseq/cli/interactive_hosted.py
NUMBER_OF_NODES * NUMBER_OF_GPUS_PER_NODE should equal MODEL_PARALLEL (the model parallel size set in the constants module). For example, a model-parallel-8 checkpoint can be served with one node and 8 GPUs per node.
- Prompt the model
curl -k http://<ip>:<port>/completions -H "Authorization: Bearer Punit" -H "Content-Type: application/json" \
-d '{
"prompt": [16853, 16947, 19678, 16709, 16647, 19493, 17495, 16688, 16397],
"temperature": 0.0,
"max_tokens": 0, "min_tokens": 0,
"top_p": 1.0, "n": 1, "best_of": 1,
"echo": true, "logprobs": 1, "seed": 1
}'
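The same request can also be issued from Python with the requests library (replace the host, port, and bearer token placeholders with your own values):
# Python equivalent of the curl call above; requests sets the
# Content-Type: application/json header automatically when json= is used.
import requests

response = requests.post(
    "http://<ip>:<port>/completions",  # replace placeholders
    headers={"Authorization": "Bearer Punit"},
    json={
        "prompt": [16853, 16947, 19678, 16709, 16647, 19493, 17495, 16688, 16397],
        "temperature": 0.0,
        "max_tokens": 0, "min_tokens": 0,
        "top_p": 1.0, "n": 1, "best_of": 1,
        "echo": True, "logprobs": 1, "seed": 1,
    },
)
print(response.json())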