From 043f5dbfa0f589d3274ae29c3cb825aa52965871 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 14 Feb 2026 17:23:05 -0500 Subject: [PATCH] fix write-after-read tracking (#14754) AFTER-AFTER was silently dropped, which breaks write-after-read --- test/null/test_tensor_metadata.py | 1 + test/unit/test_assign.py | 2 -- tinygrad/engine/schedule.py | 4 +++- tinygrad/schedule/rangeify.py | 10 +++++----- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/test/null/test_tensor_metadata.py b/test/null/test_tensor_metadata.py index ef882b1bc7..23df3b42e1 100644 --- a/test/null/test_tensor_metadata.py +++ b/test/null/test_tensor_metadata.py @@ -63,6 +63,7 @@ class TestTensorMetadata(unittest.TestCase): self.assertEqual(len(si.metadata), 3) self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"}) + @unittest.skip("flaky") def test_complex_backward(self): x = Tensor.rand(3, requires_grad=True).realize() y = Tensor.rand(3, requires_grad=True).realize() diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index 6aaee70cd0..7d169bd7ef 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -756,8 +756,6 @@ class TestAssignOrdering(unittest.TestCase): self.assertEqual(buf[0:1, :].sum().item(), 4) self.assertEqual(buf[1:2, :].sum().item(), 8) - # TODO: fix this, see https://github.com/tinygrad/tinygrad/issues/13600 - @unittest.expectedFailure def test_multi_step_assign_read_write_same_buffer(self): """Assign to m and param reading b, then update b, across multiple steps. This is the optimizer bias-correction pattern from issue #13600: m accumulates, diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index c09cdef947..8b7a7d4204 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -29,7 +29,9 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}" in_degree.setdefault(k, 0) if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}" - for s in k.src[0].src[1:] if k.op is Ops.END else k.src[1:]: + # WAR deps from rangeify are stored in AFTER src[2:] + kernel_deps = k.src[0].src[1:] if k.op is Ops.END else k.src[1:] + for s in kernel_deps + u.src[2:]: match (s := _unwrap_src(s)).op: case Ops.AFTER: children.setdefault(s.src[1], []).append(k) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 03c9eebacc..82da3cc366 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -602,15 +602,15 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: name="bufferize to store") tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels") - # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign - kernel_assign: dict[UOp, UOp] = {} + # WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish + afters = [u for u in tsink.toposort() if u.op is Ops.AFTER] + kernel_assign: dict[UOp, UOp] = {u.buf_uop:u for u in afters} assign_rep: dict[UOp, UOp] = {} - for u in tsink.toposort(): - if u.op is not Ops.AFTER: continue - kernel_assign[u.buf_uop] = u + for u in afters: for s in u.src[1].src: # TODO: this is probably broken for MSELECT/MSTACK if s.op not in {Ops.BUFFER, Ops.PARAM} or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue + if a.src[1] is u.src[1]: continue # same kernel (multi-output custom kernels) if any(x.op is Ops.AFTER and x.buf_uop is s for x in u.toposort()): raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER") assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))