diff --git a/test/test_assign.py b/test/test_assign.py index a8a0bd4d72..bcd5fc57c2 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -79,10 +79,34 @@ class TestAssign(unittest.TestCase): x.realize() x = Tensor([0]) for _ in range(5): f(x) + assert x.item() == 5 + y = Tensor([0]) for _ in range(4): f(y) assert y.item() == 4 + def test_assign_other_jit(self): + @TinyJit + def f(x, a): + x.assign(a) + x.realize() + x = Tensor([0]) + for i in range(1, 6): + f(x, x.full_like(i).contiguous()) # const would be implicitly folded without contiguous + assert x.item() == i + + def test_assign_add_other_jit(self): + @TinyJit + def f(x, a): + x += a + x.realize() + x = Tensor([0]) + a = 0 + for i in range(1, 6): + a += i + f(x, x.full_like(i).contiguous()) + assert x.item() == a + def test_assign_changes(self): a = Tensor.ones(4).contiguous().realize() old_a = a