From 4e9b85ecfdec40fe891a7be9f3db746db724200c Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Wed, 4 Mar 2026 19:15:49 +0800 Subject: [PATCH] fa: pull inputs out of call (#15127) --- extra/thunder/amd/fa.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/extra/thunder/amd/fa.py b/extra/thunder/amd/fa.py index 2a919cc507..c201f9ee04 100644 --- a/extra/thunder/amd/fa.py +++ b/extra/thunder/amd/fa.py @@ -45,8 +45,14 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False attn = _sharded_empty_like(xq, axis=shard_axis) l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t) - def grad(dou:UOp, _) -> tuple[None, None, UOp, UOp, UOp]: + def grad(dou:UOp, ker:UOp) -> tuple[None, None, UOp, UOp, UOp]: do = Tensor(dou, device=dou.device) + attn = Tensor(ker.src[1].after(ker), device=ker.src[1].device) + l_vec = Tensor(ker.src[2].after(ker), device=ker.src[2].device) + xq = Tensor(ker.src[3], device=ker.src[3].device) + xk = Tensor(ker.src[4], device=ker.src[4].device) + xv = Tensor(ker.src[5], device=ker.src[5].device) + dq = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t) GROUP_SIZE = H_local // H_KV_local dk_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xk, axis=shard_axis)