fa: pull inputs out of call (#15127)

wozeparrot
2026-03-04 19:15:49 +08:00
committed by GitHub
parent 47faa2d7b4
commit 4e9b85ecfd


@@ -45,8 +45,14 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
   attn = _sharded_empty_like(xq, axis=shard_axis)
   l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
-  def grad(dou:UOp, _) -> tuple[None, None, UOp, UOp, UOp]:
+  def grad(dou:UOp, ker:UOp) -> tuple[None, None, UOp, UOp, UOp]:
     do = Tensor(dou, device=dou.device)
+    attn = Tensor(ker.src[1].after(ker), device=ker.src[1].device)
+    l_vec = Tensor(ker.src[2].after(ker), device=ker.src[2].device)
+    xq = Tensor(ker.src[3], device=ker.src[3].device)
+    xk = Tensor(ker.src[4], device=ker.src[4].device)
+    xv = Tensor(ker.src[5], device=ker.src[5].device)
     dq = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
     GROUP_SIZE = H_local // H_KV_local
     dk_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xk, axis=shard_axis)
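
The effect of the diff: grad no longer closes over the forward tensors from the enclosing Python scope; it now takes the kernel call UOp ker and recovers attn, l_vec, xq, xk, and xv from ker.src (indices 1 through 5). Below is a minimal, self-contained sketch of that pattern; the Node class, the names, and the placeholder return values are hypothetical stand-ins for tinygrad's UOp/Tensor machinery, illustrating only the index layout, not the real backward math.

# Sketch only: Node and everything below are hypothetical stand-ins, chosen to mirror the diff.
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  name: str
  src: tuple["Node", ...] = ()

# forward: pack everything the backward pass will need into the call's sources
attn, l_vec, xq, xk, xv = (Node(n) for n in ("attn", "l_vec", "xq", "xk", "xv"))
ker = Node("flash_attention_call", (Node("out"), attn, l_vec, xq, xk, xv))

def grad(dout: Node, ker: Node) -> tuple[None, None, Node, Node, Node]:
  # pull the saved inputs back out of the call instead of capturing them in a closure
  attn, l_vec, xq, xk, xv = ker.src[1:6]
  # the real kernel would compute dq/dk/dv here; return placeholders for the sketch
  return (None, None, xq, xk, xv)

print(grad(Node("dout"), ker))

Reading the operands from the call's src ties the backward pass to the kernel itself rather than to Python-level captures; the .after(ker) on attn and l_vec in the diff presumably sequences those reads after the forward kernel has run.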