From b376bd7a2196634ea683c55a460ff3d07cfc3615 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 12 Feb 2026 15:33:32 +0300 Subject: [PATCH] jit: fix raw in same kernel (#14699) * jit: fix raw in same kernel * fix * ugh * x * simpler --- test/backend/test_jit_footguns.py | 12 ++++++++++-- tinygrad/engine/jit.py | 7 +++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/test/backend/test_jit_footguns.py b/test/backend/test_jit_footguns.py index 1fd58c28ea..b7fbab19f7 100644 --- a/test/backend/test_jit_footguns.py +++ b/test/backend/test_jit_footguns.py @@ -23,7 +23,7 @@ ERRORS RAISED (lower priority - at least users know): """ import unittest import numpy as np -from tinygrad import Tensor, TinyJit +from tinygrad import Tensor, TinyJit, Device from tinygrad.engine.jit import JitError from tinygrad.helpers import JIT @@ -66,7 +66,6 @@ class TestJitFootguns(unittest.TestCase): This requires multiple kernels to trigger because single-kernel JITs don't get graphed ("only one kernel doesn't graph"). """ - from tinygrad import Device if Device[Device.DEFAULT].graph is None or JIT != 1: self.skipTest("test requires JIT graph support") @@ -111,6 +110,15 @@ class TestJitFootguns(unittest.TestCase): self.assertEqual(first.numpy().item(), expected_first) buf = new_buf + def test_intra_kernel_output_input_aliasing(self): + """JIT must copy aliased input when output buffer is fed back as input (read-write race in same kernel).""" + N = 1 << 20 + f = TinyJit(lambda buf, new: buf[N//2:].cat(new), prune=True) + buf = Tensor.zeros(N, dtype='int32').contiguous().realize() + for i in range(10): + buf = f(buf, Tensor(np.ones(N//2, dtype=np.int32)*(i+1))) + np.testing.assert_array_equal(buf[:N//2].numpy(), np.full(N//2, i, dtype=np.int32)) + def test_slice_assign_works_without_realize(self): """Slice assign then read from same buffer - pending assigns are side-realized.""" from tinygrad import Variable diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 79fe034d39..e4959bafd8 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -187,7 +187,10 @@ class CapturedJit(Generic[ReturnType]): # precompute read-after-write hazard detection self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)} self._input_to_max_reader: dict[int, int] = {} - for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j) + for (j, i), idx in self.input_replace.items(): + # only buffers that were different during capture but alias at jit time (e.g. feeding output back as input) need the copy. + if self.jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self.jit_cache[j]): + self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j) self._clear_inputs() def _clear_inputs(self): @@ -218,7 +221,7 @@ class CapturedJit(Generic[ReturnType]): # copy aliased inputs to prevent read-after-write hazard for i, ib in enumerate(input_buffers): - if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) > writer: + if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer: input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview()) for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]