jit: fix raw in same kernel (#14699)

* jit: fix raw in same kernel

* fix

* ugh

* x

* simpler
This commit is contained in:
nimlgen
2026-02-12 15:33:32 +03:00
committed by GitHub
parent 19e68a1833
commit b376bd7a21
2 changed files with 15 additions and 4 deletions

View File

@@ -23,7 +23,7 @@ ERRORS RAISED (lower priority - at least users know):
"""
import unittest
import numpy as np
from tinygrad import Tensor, TinyJit
from tinygrad import Tensor, TinyJit, Device
from tinygrad.engine.jit import JitError
from tinygrad.helpers import JIT
@@ -66,7 +66,6 @@ class TestJitFootguns(unittest.TestCase):
This requires multiple kernels to trigger because single-kernel JITs don't get graphed ("only one kernel doesn't graph").
"""
from tinygrad import Device
if Device[Device.DEFAULT].graph is None or JIT != 1:
self.skipTest("test requires JIT graph support")
@@ -111,6 +110,15 @@ class TestJitFootguns(unittest.TestCase):
self.assertEqual(first.numpy().item(), expected_first)
buf = new_buf
def test_intra_kernel_output_input_aliasing(self):
  """Feeding a JIT output straight back in as an input must not corrupt the result.

  Each call drops the first half of `buf` and appends a fresh half-sized block, and the
  returned tensor is passed as `buf` on the next call. The JIT has to copy the aliased
  input; otherwise the same kernel would be reading and writing one buffer.
  """
  N = 1 << 20
  half = N // 2
  # sliding window: keep the back half of the previous buffer, then append the new block
  shift_window = TinyJit(lambda buf, new: buf[half:].cat(new), prune=True)
  buf = Tensor.zeros(N, dtype='int32').contiguous().realize()
  for step in range(10):
    buf = shift_window(buf, Tensor(np.full(half, step + 1, dtype=np.int32)))
    # the front half now holds the block appended on the previous call (zeros on the first call), i.e. `step`
    np.testing.assert_array_equal(buf[:half].numpy(), np.full(half, step, dtype=np.int32))
def test_slice_assign_works_without_realize(self):
"""Slice assign then read from same buffer - pending assigns are side-realized."""
from tinygrad import Variable

View File

@@ -187,7 +187,10 @@ class CapturedJit(Generic[ReturnType]):
# precompute read-after-write hazard detection
self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
self._input_to_max_reader: dict[int, int] = {}
for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
for (j, i), idx in self.input_replace.items():
# only buffers that were different during capture but alias at jit time (e.g. feeding output back as input) need the copy.
if self.jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self.jit_cache[j]):
self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
self._clear_inputs()
def _clear_inputs(self):
@@ -218,7 +221,7 @@ class CapturedJit(Generic[ReturnType]):
# copy aliased inputs to prevent read-after-write hazard
for i, ib in enumerate(input_buffers):
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) > writer:
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]