mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TUTORIALS] allow BLOCK(bwd) != BLOCK_M(fwd) in flash attention (#2020)
This commit is contained in:
@@ -342,7 +342,7 @@ class _attention(torch.autograd.Function):
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
|
||||
o, do,
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
|
||||
Reference in New Issue
Block a user