stronger test_graph_input_output_aliasing (#14282)

* stronger test_graph_input_output_aliasing

* confirmed failure
chenyu
2026-01-22 09:59:34 -05:00
committed by GitHub
parent d7afa02085
commit 8bb61c2490

@@ -53,39 +53,47 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
 
   def test_graph_input_output_aliasing(self):
-    """JIT graph fails when input=output during graph creation, then different input later.
+    """Test that JIT handles input=output aliasing correctly, simulating LLM generate pattern.
 
-    Graph-only because _input_replace is recomputed at _first_run only when JIT < 2 (graphing enabled).
-    When _first_run happens with input buffer == captured.ret buffer:
-    - get_input_replace() adds output position to input_replace (since buffer matches input_buffers)
-    - GraphRunner.__init__ skips setting buffer at output position (thinks it will be replaced)
-    - But output position isn't a true input, so it's never updated in __call__
+    The LLM generate pattern:
+    1. First "session": multiple iterations where output becomes next input
+    2. Second "session": starts with a NEW input tensor (not the previous output)
 
-    This pattern occurs in LLM token generation where output becomes next input.
+    The bug: GraphRunner computes input_replace during _first_run. If at that time input buffer == output buffer
+    (aliasing), it incorrectly includes the output position in input_replace. Later, when a DIFFERENT input
+    is passed, the output position gets overwritten with the input, corrupting the computation.
+
+    This requires multiple kernels to trigger because single-kernel JITs don't get graphed ("only one kernel doesn't graph").
     """
     from tinygrad import Device
     if Device[Device.DEFAULT].graph is None or JIT != 1:
       self.skipTest("test requires JIT graph support")
 
+    # Multiple operations to create multiple kernels that get batched into a GraphRunner
     @TinyJit
-    def step(x): return (x + 1).realize()
+    def step(x):
+      y = (x + 1).realize() # kernel 1
+      z = (y * 2).realize() # kernel 2
+      return z
 
-    # Phase 1: warmup and capture with fresh inputs
+    # Phase 1: warmup and capture
     a = Tensor([10]).contiguous().realize()
     step(a) # warmup (cnt=0)
     b = Tensor([20]).contiguous().realize()
-    captured_ret = step(b) # capture (cnt=1)
+    x = step(b) # capture (cnt=1), x = (20+1)*2 = 42
 
-    # Phase 2: first exec where input IS captured.ret (triggers _first_run with aliased buffers)
-    result = step(captured_ret) # cnt=2, _first_run=True, input_buf == output_buf
-    self.assertEqual(result.item(), 22) # 21+1=22, correct
+    # Phase 2: first "session" - iterations where output becomes input (triggers _first_run with aliasing)
+    for _ in range(3):
+      x = step(x) # (42+1)*2=86, (86+1)*2=174, (174+1)*2=350
+    self.assertEqual(x.item(), 350)
 
-    # Phase 3: subsequent exec with DIFFERENT input (exposes the bug)
-    c = Tensor([100]).contiguous().realize()
-    result = step(c) # cnt=3, different input buffer
-    # TODO: get_input_replace() incorrectly added output position to input_replace when input buffer == output buffer
-    # fix: output-only positions (in prg.p.outs but not prg.p.ins) should never be added to input_replace
-    self.assertEqual(result.item(), 22) # should be 101!
+    # Phase 3: second "session" - NEW input tensor (simulates new generate() call)
+    # The bug: GraphRunner's input_replace incorrectly includes the output position
+    # When new input y is passed, it overwrites the output buffer, using old value (350) instead of new (100)
+    y = Tensor([100]).contiguous().realize()
+    for _ in range(3):
+      y = step(y) # should be (100+1)*2=202, (202+1)*2=406, (406+1)*2=814
+    self.assertEqual(y.item(), 1406) # TODO: should be 814
 
   def test_multiple_outputs_same_intermediate(self):
     """Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""