tk: fa cleanups + causal test (#13963)

This commit is contained in:
wozeparrot
2026-01-01 21:05:00 -05:00
committed by GitHub
parent af0392efea
commit ecbac8a338
2 changed files with 38 additions and 1 deletion

View File

@@ -340,7 +340,6 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
mask = Tensor(kernel.src[5])
delta_vec = (grad * attn).sum(-1).transpose(1, 2).unsqueeze(-2).detach()
print(l_vec.shape, delta_vec.shape, grad.shape, attn.shape, grad_q.shape, grad_k.shape, grad_v.shape)
grad_q = Tensor.custom_kernel(grad_q, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]