mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
stronger test_graph_input_output_aliasing (#14282)
* stronger test_graph_input_output_aliasing
* confirmed failure
@@ -53,39 +53,47 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
 
   def test_graph_input_output_aliasing(self):
-    """JIT graph fails when input=output during graph creation, then a different input is passed later.
-
-    Graph-only because _input_replace is recomputed at _first_run only when JIT < 2 (graphing enabled).
-    When _first_run happens with input buffer == captured.ret buffer:
-    - get_input_replace() adds the output position to input_replace (since the buffer matches input_buffers)
-    - GraphRunner.__init__ skips setting the buffer at the output position (thinks it will be replaced)
-    - But the output position isn't a true input, so it's never updated in __call__
+    """Test that the JIT handles input=output aliasing correctly, simulating the LLM generate pattern.
+
+    The LLM generate pattern, where each output token feeds back as the next input:
+    1. First "session": multiple iterations where the output becomes the next input
+    2. Second "session": starts with a NEW input tensor (not the previous output)
+
+    The bug: GraphRunner computes input_replace during _first_run. If at that time input buffer == output buffer
+    (aliasing), it incorrectly includes the output position in input_replace. Later, when a DIFFERENT input
+    is passed, the output position gets overwritten with the input, corrupting the computation.
+
+    This requires multiple kernels to trigger, because single-kernel JITs don't get graphed ("only one kernel doesn't graph").
     """
     from tinygrad import Device
     if Device[Device.DEFAULT].graph is None or JIT != 1:
       self.skipTest("test requires JIT graph support")
 
+    # multiple operations create multiple kernels that get batched into a GraphRunner
     @TinyJit
-    def step(x): return (x + 1).realize()
+    def step(x):
+      y = (x + 1).realize()  # kernel 1
+      z = (y * 2).realize()  # kernel 2
+      return z
 
-    # Phase 1: warmup and capture with fresh inputs
+    # Phase 1: warmup and capture
     a = Tensor([10]).contiguous().realize()
     step(a)  # warmup (cnt=0)
     b = Tensor([20]).contiguous().realize()
-    captured_ret = step(b)  # capture (cnt=1)
+    x = step(b)  # capture (cnt=1), x = (20+1)*2 = 42
 
-    # Phase 2: first exec where the input IS captured.ret (triggers _first_run with aliased buffers)
-    result = step(captured_ret)  # cnt=2, _first_run=True, input_buf == output_buf
-    self.assertEqual(result.item(), 22)  # 21+1=22, correct
+    # Phase 2: first "session" - iterations where output becomes input (triggers _first_run with aliasing)
+    for _ in range(3):
+      x = step(x)  # (42+1)*2=86, (86+1)*2=174, (174+1)*2=350
+    self.assertEqual(x.item(), 350)
 
-    # Phase 3: subsequent exec with a DIFFERENT input (exposes the bug)
-    c = Tensor([100]).contiguous().realize()
-    result = step(c)  # cnt=3, different input buffer
-    # TODO: get_input_replace() incorrectly added the output position to input_replace when input buffer == output buffer
-    # fix: output-only positions (in prg.p.outs but not prg.p.ins) should never be added to input_replace
-    self.assertEqual(result.item(), 22)  # should be 101!
+    # Phase 3: second "session" - a NEW input tensor (simulates a new generate() call)
+    # The bug: GraphRunner's input_replace incorrectly includes the output position, so the
+    # new input y overwrites the output buffer and the graph computes from the stale value (350)
+    # instead of the new one (100)
+    y = Tensor([100]).contiguous().realize()
+    for _ in range(3):
+      y = step(y)  # should be (100+1)*2=202, (202+1)*2=406, (406+1)*2=814
+    self.assertEqual(y.item(), 1406)  # TODO: should be 814
 
   def test_multiple_outputs_same_intermediate(self):
     """Multiple outputs derived from the same intermediate - the JIT copies aliased inputs to prevent the hazard."""
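The TODO comments removed by this diff also sketch the intended fix: a position that a program only writes (in prg.p.outs but not prg.p.ins) is never a true input and should not be added to input_replace. The sketch below illustrates that filter on stand-in types; FakeExecItem, its ins/outs fields, and get_input_replace_fixed are hypothetical names for illustration, not tinygrad's actual jit internals.

```python
# illustrative sketch only: hypothetical stand-ins, NOT tinygrad's real ExecItem/get_input_replace
from dataclasses import dataclass, field

@dataclass
class FakeExecItem:
  bufs: list                              # buffers bound to the program, by position
  ins: set = field(default_factory=set)   # positions the program reads
  outs: set = field(default_factory=set)  # positions the program writes

def get_input_replace_fixed(jit_cache: list, input_buffers: list) -> dict:
  # map (item_index, buffer_position) -> index of the user input to patch in on each launch
  input_replace = {}
  for j, item in enumerate(jit_cache):
    for i, buf in enumerate(item.bufs):
      matches = [k for k, b in enumerate(input_buffers) if b is buf]
      if not matches: continue
      # the fix: an output-only position (written, never read) is not a true input,
      # even when its buffer happens to alias one at capture time, so skip it
      if i in item.outs and i not in item.ins: continue
      input_replace[(j, i)] = matches[0]
  return input_replace

# aliasing at _first_run: the output buffer IS the input buffer
inp = object()
item = FakeExecItem(bufs=[inp, inp], ins={0}, outs={1})
assert get_input_replace_fixed([item], [inp]) == {(0, 0): 0}  # output position 1 is skipped
```

Without the output-only filter, position (0, 1) would also be recorded, and a later call with a different input would overwrite the graph's output buffer, which is the 1406-instead-of-814 failure the strengthened test confirms.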