Created by: ngoyal2707
Adds a custom autograd function for the transformer block, to have more control over which ops get recomputed in the backward pass. I know the code is ugly and adds another codepath for the model :(
I think maybe we can just kill model parallel modules at some point?
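
For anyone unfamiliar with the idea, here is a minimal sketch of selective recomputation with a custom `torch.autograd.Function`. It is not the code in this PR: `RecomputeFFN` and `reference_ffn` are hypothetical names, and the block is a toy two-matmul feed-forward layer with 2D inputs rather than the real transformer block. The forward pass saves only the input and weights, and the backward pass recomputes the ReLU activation instead of keeping it in memory.

```python
import torch


class RecomputeFFN(torch.autograd.Function):
    """Toy FFN y = relu(x @ w1) @ w2 that recomputes its hidden activation.

    Hypothetical illustration only; assumes 2D x, w1, w2.
    """

    @staticmethod
    def forward(ctx, x, w1, w2):
        # Autograd is disabled inside forward, so `h` is not retained.
        h = torch.relu(x @ w1)
        # Save only the inputs; the hidden activation is dropped on purpose.
        ctx.save_for_backward(x, w1, w2)
        return h @ w2

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, w2 = ctx.saved_tensors
        # Recompute the ops we chose not to store.
        pre = x @ w1
        h = torch.relu(pre)
        # Hand-written gradients for the two matmuls and the ReLU.
        grad_w2 = h.t() @ grad_out
        grad_pre = (grad_out @ w2.t()) * (pre > 0).to(grad_out.dtype)
        grad_w1 = x.t() @ grad_pre
        grad_x = grad_pre @ w1.t()
        return grad_x, grad_w1, grad_w2


def reference_ffn(x, w1, w2):
    # Same math with ordinary autograd, which stores the activation.
    return torch.relu(x @ w1) @ w2
```

Usage is `RecomputeFFN.apply(x, w1, w2)` in place of the plain forward. The tradeoff is one extra matmul and ReLU in backward in exchange for not holding the hidden activation, and because the backward is written by hand, we decide exactly which ops are recomputed rather than rerunning the whole block as `torch.utils.checkpoint` would.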