From 67d9712ef6cfd99d43b90338c618bb4d70a58589 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sun, 18 Jan 2026 18:48:59 -0500
Subject: [PATCH] jit copy aliased output if it's read later (#14210)

---
 test/test_jit_footguns.py | 29 +++++++++--------------------
 tinygrad/engine/jit.py    | 10 ++++++++++
 2 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py
index 282efb1064..82e92a966c 100644
--- a/test/test_jit_footguns.py
+++ b/test/test_jit_footguns.py
@@ -9,7 +9,6 @@ SILENT MISMATCHES (highest priority - wrong results, no error):
 class_method_shared_across_instances  EASY  could check if first arg is self and warn
 slice_assign_requires_realize         MED   assign graph not connected to read during JIT replay
 output_buffer_reuse                   MED   performance tradeoff, could add option or better docs
-multiple_outputs_same_intermediate    MED   outputs derived from same intermediate get aliased during replay
 python_constants_frozen               HARD  inherent to tracing JITs
 conditional_branches_frozen           HARD  inherent to tracing JITs
 
@@ -52,29 +51,19 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
 
   def test_multiple_outputs_same_intermediate(self):
-    """Multiple outputs derived from the same intermediate get aliased during replay."""
-    # TODO: fix this, clone should not be required, related to https://github.com/tinygrad/tinygrad/issues/13364
+    """Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""
     @TinyJit
-    def f_broken(buf, frame):
+    def f(buf, frame):
       new_buf = buf[1:].cat(frame, dim=0)
       return new_buf.contiguous(), new_buf[:1].contiguous()
 
-    @TinyJit
-    def f_fixed(buf, frame):
-      new_buf = buf[1:].cat(frame, dim=0)
-      return new_buf.clone().contiguous(), new_buf[:1].contiguous()
-
-    for f, should_work in [(f_broken, False), (f_fixed, True)]:
-      buf = Tensor([[0], [1], [2]]).contiguous().realize()
-      for i in range(4):
-        frame = Tensor([[10+i]]).contiguous().realize()
-        expected_first = buf[1:2].numpy().item()
-        new_buf, first = f(buf, frame)
-        if should_work:
-          self.assertEqual(first.numpy().item(), expected_first)
-        else:
-          if i >= 2: self.assertNotEqual(first.numpy().item(), expected_first)  # fails on 3rd iteration!
-        buf = new_buf
+    buf = Tensor([[0], [1], [2]]).contiguous().realize()
+    for i in range(4):
+      frame = Tensor([[10+i]]).contiguous().realize()
+      expected_first = buf[1:2].numpy().item()
+      new_buf, first = f(buf, frame)
+      self.assertEqual(first.numpy().item(), expected_first)
+      buf = new_buf
 
   def test_slice_assign_requires_realize(self):
     """Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 1d5d9edb73..d2c3610610 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -178,6 +178,10 @@ class CapturedJit(Generic[ReturnType]):
     self._jit_cache: list[ExecItem] = self.jit_cache
     self._input_replace: dict[tuple[int, int], int] = self.input_replace
     self._first_run = True
+    # precompute read-after-write hazard detection
+    self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
+    self._input_to_max_reader: dict[int, int] = {}
+    for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
     self._clear_inputs()
 
   def _clear_inputs(self):
@@ -205,6 +209,12 @@ class CapturedJit(Generic[ReturnType]):
     # assign inputs
     for idx, offset, device, size, dtype in self.extra_view_inputs:
       input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+
+    # copy aliased inputs to prevent read-after-write hazard
+    for i, ib in enumerate(input_buffers):
+      if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) > writer:
+        input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_buffer())
+
    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
 
     # Condense the items into a graph executor.
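
Note: a minimal, self-contained sketch of the hazard check added above, using
plain-Python stand-ins for ExecItem/Buffer. The names steps_writes,
input_replace (as a literal), and the needs_copy helper are hypothetical
illustrations, not tinygrad API:

    # replay schedule: step j writes the named buffers; JIT inputs are patched
    # into (step, slot) positions before replay, like input_replace above
    steps_writes = [["buf_a"], ["buf_b"]]   # step 0 writes buf_a, step 1 writes buf_b
    input_replace = {(0, 0): 0, (1, 0): 0}  # input 0 is read by step 0 and step 1

    # map every output buffer to the step that writes it (last writer wins)
    output_to_writer = {b: j for j, outs in enumerate(steps_writes) for b in outs}

    # map every input index to the last step that reads it
    input_to_max_reader: dict[int, int] = {}
    for (j, _), idx in input_replace.items():
      input_to_max_reader[idx] = max(input_to_max_reader.get(idx, -1), j)

    def needs_copy(input_idx: int, buf_name: str) -> bool:
      # hazard: the buffer passed in as an input is also overwritten by step
      # `writer`, yet a later step still expects to read the original value
      writer = output_to_writer.get(buf_name)
      return writer is not None and input_to_max_reader.get(input_idx, -1) > writer

    assert needs_copy(0, "buf_a")      # step 0 clobbers it before step 1 reads it
    assert not needs_copy(0, "buf_b")  # written by step 1, never read after that

This mirrors what the updated test exercises: feeding f's first output back in
as buf makes the input buffer alias a JIT output buffer, so without the
defensive copyin the kernel producing new_buf would overwrite buf before the
kernel producing first reads it.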