From 8bb61c2490bc3e77ed5eb080640f8d7754201e1a Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 22 Jan 2026 09:59:34 -0500 Subject: [PATCH] stronger test_graph_input_output_aliasing (#14282) * stronger test_graph_input_output_aliasing * comfirmed failure --- test/test_jit_footguns.py | 46 +++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py index 7142cb1d91..534026268c 100644 --- a/test/test_jit_footguns.py +++ b/test/test_jit_footguns.py @@ -53,39 +53,47 @@ class TestJitFootguns(unittest.TestCase): self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6]) def test_graph_input_output_aliasing(self): - """JIT graph fails when input=output during graph creation, then different input later. + """Test that JIT handles input=output aliasing correctly, simulating LLM generate pattern. - Graph-only because _input_replace is recomputed at _first_run only when JIT < 2 (graphing enabled). - When _first_run happens with input buffer == captured.ret buffer: - - get_input_replace() adds output position to input_replace (since buffer matches input_buffers) - - GraphRunner.__init__ skips setting buffer at output position (thinks it will be replaced) - - But output position isn't a true input, so it's never updated in __call__ + The LLM generate pattern: + 1. First "session": multiple iterations where output becomes next input + 2. Second "session": starts with a NEW input tensor (not the previous output) - This pattern occurs in LLM token generation where output becomes next input. + The bug: GraphRunner computes input_replace during _first_run. If at that time input buffer == output buffer + (aliasing), it incorrectly includes the output position in input_replace. Later, when a DIFFERENT input + is passed, the output position gets overwritten with the input, corrupting the computation. + + This requires multiple kernels to trigger because single-kernel JITs don't get graphed ("only one kernel doesn't graph"). """ from tinygrad import Device if Device[Device.DEFAULT].graph is None or JIT != 1: self.skipTest("test requires JIT graph support") + # Multiple operations to create multiple kernels that get batched into a GraphRunner @TinyJit - def step(x): return (x + 1).realize() + def step(x): + y = (x + 1).realize() # kernel 1 + z = (y * 2).realize() # kernel 2 + return z - # Phase 1: warmup and capture with fresh inputs + # Phase 1: warmup and capture a = Tensor([10]).contiguous().realize() step(a) # warmup (cnt=0) b = Tensor([20]).contiguous().realize() - captured_ret = step(b) # capture (cnt=1) + x = step(b) # capture (cnt=1), x = (20+1)*2 = 42 - # Phase 2: first exec where input IS captured.ret (triggers _first_run with aliased buffers) - result = step(captured_ret) # cnt=2, _first_run=True, input_buf == output_buf - self.assertEqual(result.item(), 22) # 21+1=22, correct + # Phase 2: first "session" - iterations where output becomes input (triggers _first_run with aliasing) + for _ in range(3): + x = step(x) # (42+1)*2=86, (86+1)*2=174, (174+1)*2=350 + self.assertEqual(x.item(), 350) - # Phase 3: subsequent exec with DIFFERENT input (exposes the bug) - c = Tensor([100]).contiguous().realize() - result = step(c) # cnt=3, different input buffer - # TODO: get_input_replace() incorrectly added output position to input_replace when input buffer == output buffer - # fix: output-only positions (in prg.p.outs but not prg.p.ins) should never be added to input_replace - self.assertEqual(result.item(), 22) # should be 101! + # Phase 3: second "session" - NEW input tensor (simulates new generate() call) + # The bug: GraphRunner's input_replace incorrectly includes the output position + # When new input y is passed, it overwrites the output buffer, using old value (350) instead of new (100) + y = Tensor([100]).contiguous().realize() + for _ in range(3): + y = step(y) # should be (100+1)*2=202, (202+1)*2=406, (406+1)*2=814 + self.assertEqual(y.item(), 1406) # TODO: should be 814 def test_multiple_outputs_same_intermediate(self): """Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""