diff --git a/test/test_assign.py b/test/test_assign.py index 718faad44c..d63c092abc 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -424,5 +424,22 @@ class TestAssign(unittest.TestCase): a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy() np.testing.assert_equal(a, np.ones(5)) + def test_assign_slice_then_read(self): + """Assign to slice then read from buffer - read should see the assigned values. + This is the KV cache pattern from llm.py. + """ + from tinygrad import Variable + v_pos = Variable("pos", 0, 3).bind(0) + + # without .realize() after assign, the read doesn't see the assigned values + cache = Tensor.zeros(4, 4).contiguous().realize() + cache[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4)) + self.assertEqual(cache.sum().item(), 0.0) # should be 4.0! + + # TODO: remove .realize() workaround once assign-read dependency is fixed + cache2 = Tensor.zeros(4, 4).contiguous().realize() + cache2[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4)).realize() + self.assertEqual(cache2.sum().item(), 4.0) + if __name__ == "__main__": unittest.main() diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py index 704653d143..2afc494f3e 100644 --- a/test/test_jit_footguns.py +++ b/test/test_jit_footguns.py @@ -7,6 +7,7 @@ Comments marked "should be X!" indicate the intuitively expected value.
SILENT MISMATCHES (highest priority - wrong results, no error): class_method_shared_across_instances EASY could check if first arg is self and warn + slice_assign_requires_realize MED assign graph not connected to read during JIT replay output_buffer_reuse MED performance tradeoff, could add option or better docs python_constants_frozen HARD inherent to tracing JITs conditional_branches_frozen HARD inherent to tracing JITs @@ -49,6 +50,31 @@ class TestJitFootguns(unittest.TestCase): self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6]) + def test_slice_assign_requires_realize(self): + """Slice assign then read from same buffer - assign isn't connected to read without explicit realize().""" + from tinygrad import Variable + v_pos = Variable("pos", 0, 3) + + # without .realize() after assign, the read doesn't see the assigned values + cache = Tensor.zeros(4, 4).contiguous().realize() + @TinyJit + def f_broken(pos): + cache[pos:pos+1, :].assign(Tensor.ones(1, 4)) + return cache.sum().realize() + for i in range(4): + cache.assign(Tensor.zeros(4, 4)).realize() + self.assertEqual(f_broken(v_pos.bind(i)).item(), 0.0) # should be 4.0! + + # workaround: add .realize() after assign + cache2 = Tensor.zeros(4, 4).contiguous().realize() + @TinyJit + def f_fixed(pos): + cache2[pos:pos+1, :].assign(Tensor.ones(1, 4)).realize() + return cache2.sum().realize() + for i in range(4): + cache2.assign(Tensor.zeros(4, 4)).realize() + self.assertEqual(f_fixed(v_pos.bind(i)).item(), 4.0) + def test_non_tensor_outputs_error(self): @TinyJit def f(x, mult): return (x * 2).realize(), mult * 10