Created by: sriniiyer
Description: This diff allows us to backpropagate losses only on designated target tokens rather than on all input tokens. It introduces a new jsonl data format in which each record specifies "src" and "tgt" keys. To use it, select the streaming_finetune_language_modeling task. The task concatenates the src and the tgt and trains the model to produce the concatenated sequence autoregressively, but only the tgt tokens contribute to the loss. This is particularly useful during fine-tuning.
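For reviewers, here is a minimal sketch of the idea, not the code in this diff. It assumes one JSON object per line with "src" and "tgt" keys (the only names taken from the description). The helper names `build_example` and `masked_nll`, the whitespace tokenizer, the sample record, and averaging the loss over tgt tokens are all hypothetical choices made for illustration.

```python
import json
import math

# Hypothetical sample in the src/tgt jsonl format (one JSON object per line).
JSONL = "\n".join([
    json.dumps({"src": "The man broke his toe. What was the cause?",
                "tgt": "He dropped a hammer on his foot."}),
])


def build_example(record, tokenize):
    """Concatenate src and tgt tokens and mark which positions count toward the loss."""
    src_ids = tokenize(record["src"])
    tgt_ids = tokenize(record["tgt"])
    tokens = src_ids + tgt_ids
    # 0 = src position (ignored by the loss), 1 = tgt position (contributes to the loss).
    loss_mask = [0] * len(src_ids) + [1] * len(tgt_ids)
    return tokens, loss_mask


def masked_nll(token_log_probs, loss_mask):
    """Average negative log-likelihood over masked-in (tgt) positions only.

    token_log_probs[i] is assumed to be the model's log-probability of tokens[i]
    given the preceding tokens. Normalizing by the tgt token count is an assumption.
    """
    total = sum(-lp for lp, m in zip(token_log_probs, loss_mask) if m)
    count = sum(loss_mask)
    return total / count if count else 0.0


# Toy usage with a whitespace tokenizer standing in for the real one.
records = [json.loads(line) for line in JSONL.splitlines() if line.strip()]
tokens, loss_mask = build_example(records[0], str.split)
```

Because src positions are masked out, the log-probabilities the model assigns to src tokens have no effect on the loss value or on the gradients.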
Test plan:
- Does not change existing pre-training: ran the existing pre-training setup on the 125M model, and the perplexity (ppl) for the first 5 updates was exactly the same as before.
- Runs on the src-tgt format without crashing and performs competitively on COPA from SuperGLUE.