tk: fa cleanups + causal test (#13963)

This commit is contained in:
wozeparrot
2026-01-01 21:05:00 -05:00
committed by GitHub
parent af0392efea
commit ecbac8a338
2 changed files with 38 additions and 1 deletions

View File

@@ -340,7 +340,6 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
mask = Tensor(kernel.src[5])
delta_vec = (grad * attn).sum(-1).transpose(1, 2).unsqueeze(-2).detach()
print(l_vec.shape, delta_vec.shape, grad.shape, attn.shape, grad_q.shape, grad_k.shape, grad_v.shape)
grad_q = Tensor.custom_kernel(grad_q, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]

View File

@@ -802,5 +802,43 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
def test_fast_fa_bwd_causal(self):
from extra.thunder.tiny.fa import flash_attention
Tensor.manual_seed(42)
B, N, H, H_KV, D = 1, 32, 2, 1, 32
with Context(DEBUG=0):
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
Tensor.realize(q, k, v)
do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
Tensor.realize(do)
q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
out = flash_attention(q_, k_, v_, is_causal=True)
out = out.float().transpose(1, 2)
out.backward(do)
Tensor.realize(q.grad, k.grad, v.grad)
with Context(DEBUG=0):
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
Tensor.realize(q_ref, k_ref, v_ref)
q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_, is_causal=True)
ref = ref.float().transpose(1, 2)
ref.backward(do)
Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
if __name__ == "__main__":
unittest.main()