diff --git a/python/perf-kernels/flash-attention-seqlen-padded.py b/python/perf-kernels/flash-attention-seqlen-padded.py
index 19291a825..db83e5f2a 100644
--- a/python/perf-kernels/flash-attention-seqlen-padded.py
+++ b/python/perf-kernels/flash-attention-seqlen-padded.py
@@ -500,7 +500,7 @@ class _attention(torch.autograd.Function):
             num_stages=1, num_warps=num_warps
         )
-        ctx.save_for_backward(q, k, v, o[:,:,0:seqlen,:], M)
+        ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid
         ctx.sm_scale = sm_scale
         ctx.BLOCK_DMODEL = Lk
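
For context, the changed line relies on the `save_for_backward`/`saved_tensors` contract of `torch.autograd.Function`: whatever tensors the forward stashes are handed back verbatim to the backward. Saving the full `o` instead of the slice `o[:,:,0:seqlen,:]` presumably keeps the backward pass working on the padded buffer with the exact shape and strides the forward kernel produced. The sketch below is a minimal, hypothetical op (not from this file) illustrating that contract:

```python
import torch


class _ScaledSquare(torch.autograd.Function):
    """Minimal sketch of the save_for_backward / saved_tensors contract
    used by _attention above; the op itself is illustrative only."""

    @staticmethod
    def forward(ctx, x, scale):
        out = scale * x * x
        # Tensors stashed here come back, unchanged, as ctx.saved_tensors
        # in backward(). Saving a sliced view (e.g. out[:, :10]) would hand
        # backward() a tensor whose shape/strides differ from the full
        # buffer the forward pass wrote -- the mismatch the diff above
        # appears to be avoiding.
        ctx.save_for_backward(x)
        ctx.scale = scale  # non-tensor state goes on ctx directly
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # d/dx (scale * x^2) = 2 * scale * x; None matches the scale arg.
        return grad_out * 2 * ctx.scale * x, None


x = torch.randn(4, 8, requires_grad=True)
_ScaledSquare.apply(x, 0.5).sum().backward()
```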