Created by: stephenroller
Patch Description
Because we do manual activation checkpointing, MHA needs a custom backward pass. This patch uses the flash attention implementation from xformers.
TODO:
- Gate behind the appropriate (existing) flag, and allow the vanilla implementation to still exist
- Add a test
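A rough, dependency-free sketch of what the two TODO items could look like together: a flag that chooses between a vanilla path and a fused-style path, plus a numeric comparison of the two. Everything here is an assumption for illustration. `use_fused` is a hypothetical flag name, not the existing flag the TODO refers to, and `tiled_attention` is a toy online-softmax loop standing in for the xformers kernel, which this sketch does not call.

```python
import math


def vanilla_attention(q, k, v):
    """Reference attention: softmax(q k^T / sqrt(d)) v, storing every score."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vj[c] for w, vj in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out


def tiled_attention(q, k, v, block=2):
    """Same math, but keys/values are consumed in blocks with a running max
    and running sum (online softmax), so a full score row is never stored."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qi in q:
        run_max, run_sum, acc = -math.inf, 0.0, [0.0] * dv
        for start in range(0, len(k), block):
            kb, vb = k[start:start + block], v[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            new_max = max(run_max, max(scores))
            # Rescale what was accumulated under the previous running max.
            correction = math.exp(run_max - new_max)
            run_sum *= correction
            acc = [a * correction for a in acc]
            for s, vj in zip(scores, vb):
                w = math.exp(s - new_max)
                run_sum += w
                acc = [a + w * x for a, x in zip(acc, vj)]
            run_max = new_max
        out.append([a / run_sum for a in acc])
    return out


def attention(q, k, v, use_fused=False):
    """Hypothetical gate: the vanilla path stays the default."""
    if use_fused:
        return tiled_attention(q, k, v)
    return vanilla_attention(q, k, v)


def max_abs_diff(a, b):
    """Largest elementwise gap between two outputs of the same shape."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Tiny made-up inputs: 3 queries, 3 keys/values, head dimension 2.
q = [[0.1, 0.2], [0.3, -0.4], [1.0, 0.5]]
k = [[0.5, -0.1], [0.2, 0.7], [-0.3, 0.4]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

reference = attention(q, k, v, use_fused=False)
fused = attention(q, k, v, use_fused=True)
diff = max_abs_diff(reference, fused)
```

The two paths sum in a different order, so `diff` is generally small rather than guaranteed to be zero. A real test would presumably compare both forward outputs and gradients of the actual kernels within a tolerance chosen for the training dtype.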
Testing steps
At very large scale, this was a ~0.5-1% speedup. Probably not worth it at the largest scales given the risk of numeric changes, but it may still be worth it at medium scales.