Remove slicing for out in save for bwd

This commit is contained in:
Vinayak Gokhale
2023-11-22 05:29:54 +00:00
committed by Vinayak Gokhale
parent e0a4d97569
commit dc62569e57

View File

@@ -500,7 +500,7 @@ class _attention(torch.autograd.Function):
num_stages=1, num_warps=num_warps
)
ctx.save_for_backward(q, k, v, o[:,:,0:seqlen,:], M)
ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk