Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-06 21:53:53 -05:00)
fa: failing test for bwd jit (#14009)
* tk: failing test for bwd jit
* feat: mark expectedFailure
* clean: spaces
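For context, below is a minimal sketch of the pattern the new test exercises: a TinyJit-wrapped function that calls backward() inside the jit and returns the gradients. Everything in the sketch (the names step, x, w, the shapes, and the toy sum-of-products loss) is illustrative and not taken from this commit; the actual test wraps extra.thunder.tiny.fa.flash_attention and is marked expectedFailure because, per the commit title, the jitted backward currently fails for it.

# Hedged sketch, not part of the commit: the TinyJit-around-backward pattern in isolation.
# Names (step, x, w) and the toy loss are assumptions, not from the diff.
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor, w: Tensor):
  loss = (x * w).sum()            # toy loss; the real test uses flash_attention
  loss.backward()                 # backward() runs inside the jitted function
  Tensor.realize(loss, w.grad)    # realize so the backward kernels are captured
  return w.grad

for i in range(4):                # TinyJit captures on the warm-up runs, then replays
  x = Tensor.randn(4, 4).contiguous()
  w = Tensor.randn(4, 4, requires_grad=True).contiguous()
  Tensor.realize(x, w)
  g = step(x, w)
  print(i, g.numpy().shape)       # analytically, d(sum(x*w))/dw == x

The commit's test follows this same structure, but with fresh bfloat16 q/k/v inputs each iteration and a non-jitted scaled_dot_product_attention reference for the gradient comparison, as shown in the diff below.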
@@ -807,7 +807,7 @@ class TestTK(unittest.TestCase):
     Tensor.manual_seed(42)
 
-    B, N, H, H_KV, D = 1, 32, 2, 1, 32
+    B, N, H, H_KV, D = 1, 1024, 32, 32, 128
 
     with Context(DEBUG=0):
       q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
@@ -840,5 +840,57 @@ class TestTK(unittest.TestCase):
       np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
       np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
 
+  @unittest.expectedFailure
+  def test_fast_fa_bwd_causal_jitted(self):
+    from extra.thunder.tiny.fa import flash_attention
+
+    Tensor.manual_seed(42)
+
+    B, N, H, H_KV, D = 1, 1024, 32, 32, 128
+
+    with Context(DEBUG=0):
+      q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      Tensor.realize(q, k, v)
+
+      do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
+      Tensor.realize(do)
+
+    def fn(q, k, v, do):
+      q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+      out = flash_attention(q_, k_, v_, is_causal=True)
+      out = out.float().transpose(1, 2)
+      out.backward(do)
+      Tensor.realize(out, q.grad, k.grad, v.grad)
+      return q.grad, k.grad, v.grad
+
+    fn_jitted = TinyJit(fn)
+
+    for _ in range(10):
+      q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      Tensor.realize(q, k, v)
+      do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
+      Tensor.realize(do)
+      q.grad, k.grad, v.grad = fn_jitted(q, k, v, do)
+
+    with Context(DEBUG=0):
+      q_ref = q.detach().clone().requires_grad_(True)
+      k_ref = k.detach().clone().requires_grad_(True)
+      v_ref = v.detach().clone().requires_grad_(True)
+      Tensor.realize(q_ref, k_ref, v_ref)
+
+      q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
+      ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_, is_causal=True)
+      ref = ref.float().transpose(1, 2)
+      ref.backward(do)
+      Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
+
+      np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
+      np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
+      np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
+
 if __name__ == "__main__":
   unittest.main()