Created by: suchenzang
From MHA, removed the following args:
need_head_weights
need_weights
static_kv
before_softmax
Cleaned up some cluttered returns from MHA (attn_weights
and l_aux
).
Follow-up: https://github.com/facebookresearch/metaseq/issues/423
Tested with 2x model parallel 125m, and non-model-parallel 125m.