Add failing assign test cases with write-before-read (#14148)

Slice assign with write-before-read currently fails; this is why the KV cache needs a realize().
This commit is contained in:
chenyu
2026-01-14 10:30:50 -05:00
committed by GitHub
parent 986e865830
commit 899a56446e
2 changed files with 42 additions and 0 deletions

View File

@@ -424,5 +424,21 @@ class TestAssign(unittest.TestCase):
a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
np.testing.assert_equal(a, np.ones(5))
def test_assign_slice_then_read(self):
  """Write to a slice of a buffer, then read the full buffer back.

  Mirrors the KV cache pattern from llm.py: the read should observe the
  values written by the slice assign.
  """
  bound_pos = Variable("pos", 0, 3).bind(0)
  # Broken case: with no .realize() after the assign, the following read
  # does not pick up the assigned values.
  broken = Tensor.zeros(4, 4).contiguous().realize()
  broken[bound_pos:bound_pos+1, :].assign(Tensor.ones(1, 4))
  self.assertEqual(broken.sum().item(), 0.0)  # should be 4.0!
  # TODO: remove .realize() workaround once assign-read dependency is fixed
  fixed = Tensor.zeros(4, 4).contiguous().realize()
  fixed[bound_pos:bound_pos+1, :].assign(Tensor.ones(1, 4)).realize()
  self.assertEqual(fixed.sum().item(), 4.0)
# Allow running this test file directly from the command line.
if __name__ == "__main__":
  unittest.main()

View File

@@ -7,6 +7,7 @@ Comments marked "should be X!" indicate the intuitively expected value.
SILENT MISMATCHES (highest priority - wrong results, no error):
class_method_shared_across_instances EASY could check if first arg is self and warn
slice_assign_requires_realize MED assign graph not connected to read during JIT replay
output_buffer_reuse MED performance tradeoff, could add option or better docs
python_constants_frozen HARD inherent to tracing JITs
conditional_branches_frozen HARD inherent to tracing JITs
@@ -49,6 +50,31 @@ class TestJitFootguns(unittest.TestCase):
self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
def test_slice_assign_requires_realize(self):
  """Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""
  from tinygrad import Variable
  pos_var = Variable("pos", 0, 3)
  # Broken case: the assign inside the jitted function is never realized,
  # so the sum computed afterwards does not see the written values.
  buf = Tensor.zeros(4, 4).contiguous().realize()
  @TinyJit
  def f_broken(pos):
    buf[pos:pos+1, :].assign(Tensor.ones(1, 4))
    return buf.sum().realize()
  for step in range(4):
    buf.assign(Tensor.zeros(4, 4)).realize()
    self.assertEqual(f_broken(pos_var.bind(step)).item(), 0.0)  # should be 4.0!
  # Workaround: call .realize() right after the assign.
  buf2 = Tensor.zeros(4, 4).contiguous().realize()
  @TinyJit
  def f_fixed(pos):
    buf2[pos:pos+1, :].assign(Tensor.ones(1, 4)).realize()
    return buf2.sum().realize()
  for step in range(4):
    buf2.assign(Tensor.zeros(4, 4)).realize()
    self.assertEqual(f_fixed(pos_var.bind(step)).item(), 4.0)
def test_non_tensor_outputs_error(self):
@TinyJit
def f(x, mult): return (x * 2).realize(), mult * 10