[TUTORIALS] allow BLOCK(bwd) != BLOCK_M(fwd) in flash attention (#2020)

This commit is contained in:
YouJiacheng
2023-08-18 06:31:53 +08:00
committed by GitHub
parent 2d513dbf50
commit 0970a297b2

View File

@@ -342,7 +342,7 @@ class _attention(torch.autograd.Function):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
delta = torch.empty_like(L)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
o, do,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,