Created by: punitkoura
Patch Description We are trying to optimize model training by having document level attention. What this entails is that we modify the attention mask to just attend to tokens in the current document, as opposed to having a constant triangular (causal) attention mask. The goal here is to ensure that tokens don't attend across documents.
To accomplish this, we
- Change the constant upper triangular mask to depend on the data separator tokens, specified by the
self_attn_doc_sep
field - Reset the positional encoding at the end of each document.
Testing steps Consider a list of 20 documents. We consider the following two cases -
- Passing the documents, one per sequence, in a batch of 20 sequences. No special attention mask is used, we still use the upper triangular mask.
- Passing the documents packed in a single sequence, separated by a special document separator token. Document attention is used here.
By intuition, we can see that the gradients in these two scenarios should match. We find that it is indeed the case.
Documents packed, with document level attention
> /fsx-mudslide/punitkoura/src/metaseq/metaseq/trainer.py(781)train_step()
-> grad_norm = self.clip_grad_norm(
(Pdb) self.model.decoder.embed_positions.weight.grad[2:10]
tensor([[-6.5479e-03, 9.6962e-03, 2.6375e-03, ..., -5.1121e-03,
-1.9985e-02, 2.4901e-03],
[ 2.5498e-02, -7.1761e-03, 2.2065e-02, ..., 6.4638e-03,
-1.0154e-02, 5.8792e-03],
[ 7.9259e-05, -4.0577e-03, 3.9698e-03, ..., 2.3798e-04,
3.1224e-03, 1.0833e-02],
...,
[-7.4849e-06, 1.7103e-03, -1.5433e-03, ..., 2.2430e-03,
-9.3732e-03, 1.5123e-03],
[-1.4408e-03, -1.8149e-03, 6.9925e-06, ..., 1.1794e-03,
-1.4160e-03, 2.6913e-03],
[-1.4234e-03, -6.3481e-03, 7.6257e-03, ..., 7.9472e-03,
-1.1467e-03, 3.3719e-03]], device='cuda:0')
(Pdb)
Documents sent one by one
> /fsx-mudslide/punitkoura/src/metaseq/metaseq/trainer.py(781)train_step()
-> grad_norm = self.clip_grad_norm(
(Pdb) self.model.decoder.embed_positions.weight.grad[2:10]
tensor([[-6.5555e-03, 9.6926e-03, 2.6306e-03, ..., -5.1131e-03,
-1.9988e-02, 2.4959e-03],
[ 2.5496e-02, -7.1761e-03, 2.2073e-02, ..., 6.4701e-03,
-1.0157e-02, 5.8840e-03],
[ 7.5817e-05, -4.0496e-03, 3.9688e-03, ..., 2.3734e-04,
3.1189e-03, 1.0830e-02],
...,
[-1.3878e-05, 1.7050e-03, -1.5432e-03, ..., 2.2473e-03,
-9.3719e-03, 1.5083e-03],
[-1.4369e-03, -1.8114e-03, 5.8990e-06, ..., 1.1820e-03,
-1.4172e-03, 2.6962e-03],
[-1.4229e-03, -6.3467e-03, 7.6208e-03, ..., 7.9439e-03,
-1.1527e-03, 3.3779e-03]], device='cuda:0')
(Pdb)