mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
failed assign test cases with write before read (#14148)
slice assign write before read fails now. this is why kv cache needs a realize
This commit is contained in:
@@ -424,5 +424,21 @@ class TestAssign(unittest.TestCase):
|
||||
a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
|
||||
np.testing.assert_equal(a, np.ones(5))
|
||||
|
||||
def test_assign_slice_then_read(self):
|
||||
"""Assign to slice then read from buffer - read should see the assigned values.
|
||||
This is the KV cache pattern from llm.py.
|
||||
"""
|
||||
v_pos = Variable("pos", 0, 3).bind(0)
|
||||
|
||||
# without .realize() after assign, the read doesn't see the assigned values
|
||||
cache = Tensor.zeros(4, 4).contiguous().realize()
|
||||
cache[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4))
|
||||
self.assertEqual(cache.sum().item(), 0.0) # should be 4.0!
|
||||
|
||||
# TODO: remove .realize() workaround once assign-read dependency is fixed
|
||||
cache2 = Tensor.zeros(4, 4).contiguous().realize()
|
||||
cache2[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4)).realize()
|
||||
self.assertEqual(cache2.sum().item(), 4.0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -7,6 +7,7 @@ Comments marked "should be X!" indicate the intuitively expected value.
|
||||
|
||||
SILENT MISMATCHES (highest priority - wrong results, no error):
|
||||
class_method_shared_across_instances EASY could check if first arg is self and warn
|
||||
slice_assign_requires_realize MED assign graph not connected to read during JIT replay
|
||||
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
||||
python_constants_frozen HARD inherent to tracing JITs
|
||||
conditional_branches_frozen HARD inherent to tracing JITs
|
||||
@@ -49,6 +50,31 @@ class TestJitFootguns(unittest.TestCase):
|
||||
|
||||
self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
|
||||
|
||||
def test_slice_assign_requires_realize(self):
|
||||
"""Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""
|
||||
from tinygrad import Variable
|
||||
v_pos = Variable("pos", 0, 3)
|
||||
|
||||
# without .realize() after assign, the read doesn't see the assigned values
|
||||
cache = Tensor.zeros(4, 4).contiguous().realize()
|
||||
@TinyJit
|
||||
def f_broken(pos):
|
||||
cache[pos:pos+1, :].assign(Tensor.ones(1, 4))
|
||||
return cache.sum().realize()
|
||||
for i in range(4):
|
||||
cache.assign(Tensor.zeros(4, 4)).realize()
|
||||
self.assertEqual(f_broken(v_pos.bind(i)).item(), 0.0) # should be 4.0!
|
||||
|
||||
# workaround: add .realize() after assign
|
||||
cache2 = Tensor.zeros(4, 4).contiguous().realize()
|
||||
@TinyJit
|
||||
def f_fixed(pos):
|
||||
cache2[pos:pos+1, :].assign(Tensor.ones(1, 4)).realize()
|
||||
return cache2.sum().realize()
|
||||
for i in range(4):
|
||||
cache2.assign(Tensor.zeros(4, 4)).realize()
|
||||
self.assertEqual(f_fixed(v_pos.bind(i)).item(), 4.0)
|
||||
|
||||
def test_non_tensor_outputs_error(self):
|
||||
@TinyJit
|
||||
def f(x, mult): return (x * 2).realize(), mult * 10
|
||||
|
||||
Reference in New Issue
Block a user