diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py
index e71c8bf947..a932a62c9c 100644
--- a/test/testextra/test_tk.py
+++ b/test/testextra/test_tk.py
@@ -807,7 +807,7 @@ class TestTK(unittest.TestCase):
 
     Tensor.manual_seed(42)
 
-    B, N, H, H_KV, D = 1, 32, 2, 1, 32
+    B, N, H, H_KV, D = 1, 1024, 32, 32, 128
 
     with Context(DEBUG=0):
       q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
@@ -840,5 +840,57 @@ class TestTK(unittest.TestCase):
     np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
     np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
 
+  @unittest.expectedFailure
+  def test_fast_fa_bwd_causal_jitted(self):
+    from extra.thunder.tiny.fa import flash_attention
+
+    Tensor.manual_seed(42)
+
+    B, N, H, H_KV, D = 1, 1024, 32, 32, 128
+
+    with Context(DEBUG=0):
+      q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      Tensor.realize(q, k, v)
+
+      do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
+      Tensor.realize(do)
+
+    def fn(q, k, v, do):
+      q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+      out = flash_attention(q_, k_, v_, is_causal=True)
+      out = out.float().transpose(1, 2)
+      out.backward(do)
+      Tensor.realize(out, q.grad, k.grad, v.grad)
+      return q.grad, k.grad, v.grad
+
+    fn_jitted = TinyJit(fn)
+
+    for _ in range(10):
+      q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      Tensor.realize(q, k, v)
+      do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
+      Tensor.realize(do)
+      q.grad, k.grad, v.grad = fn_jitted(q, k, v, do)
+
+    with Context(DEBUG=0):
+      q_ref = q.detach().clone().requires_grad_(True)
+      k_ref = k.detach().clone().requires_grad_(True)
+      v_ref = v.detach().clone().requires_grad_(True)
+      Tensor.realize(q_ref, k_ref, v_ref)
+
+      q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
+      ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_, is_causal=True)
+      ref = ref.float().transpose(1, 2)
+      ref.backward(do)
+      Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
+
+    np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
+    np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
+    np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
+
 if __name__ == "__main__":
   unittest.main()