Created by: stephenroller
Update: buggy. These results are invalid.
Patch Description Switches us to flashattention. Did it quick and dirty to test.
Experimental Setup Tried launching the 2.7B OPT baseline on 64 gpus. Observed that flashattention didn't like this, because its dimensionality of the attention heads is 40, and flashattention only supports 32/64/128. Resolved this by swapping # heads and head dimension, so that we have 40 heads in R^32 instead of 32 heads in R^40.
To control for this, launched 3 versions:
- (green) Our Megatron-based attention implementation, with the original OPT setup
- (orange) A Megatron-based attention implementation with 40 heads in R^32
- (purple) Flash attention based with 40 heads in R^32
Results We observe the flash attention significantly reduces memory usage:
We observe that flash attention significantly increases throughput:
I was quite surprised by how extreme this speedup is. Based on other's reports and the known FLOPS ratio, it should've been closer to 1.1x, not the 1.6x we're observing. Perhaps the key is that we get to avoid the transpose now? Our FLOPS/GPU aren't particularly high in any of these cases: the baseline is 77 TFLOPS/GPU and flash attention is 128 TFLOPS/GPU, neither of which is particularly great. We did not spend very much time ever optimizing the 2.7B, but it seems like we must be doing something really bad with it.
But here's the downside. It seems to significantly hurt stability and decrease convergence
Unfortunately, unless the stability is resolved, I can't recommend flash attention replace our current implementation.
Next steps A version based on the triton implementation is likely preferable.