fix write-after-read tracking (#14754)

The AFTER-on-AFTER dependency (a WAR dep stored in AFTER src[2:]) was silently dropped when building the schedule, which broke write-after-read ordering.
This commit is contained in:
chenyu
2026-02-14 17:23:05 -05:00
committed by GitHub
parent d79c63a0ff
commit 043f5dbfa0
4 changed files with 9 additions and 8 deletions

View File

@@ -63,6 +63,7 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@unittest.skip("flaky")
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()

View File

@@ -756,8 +756,6 @@ class TestAssignOrdering(unittest.TestCase):
self.assertEqual(buf[0:1, :].sum().item(), 4)
self.assertEqual(buf[1:2, :].sum().item(), 8)
# TODO: fix this, see https://github.com/tinygrad/tinygrad/issues/13600
@unittest.expectedFailure
def test_multi_step_assign_read_write_same_buffer(self):
"""Assign to m and param reading b, then update b, across multiple steps.
This is the optimizer bias-correction pattern from issue #13600: m accumulates,

View File

@@ -29,7 +29,9 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
for s in k.src[0].src[1:] if k.op is Ops.END else k.src[1:]:
# WAR deps from rangeify are stored in AFTER src[2:]
kernel_deps = k.src[0].src[1:] if k.op is Ops.END else k.src[1:]
for s in kernel_deps + u.src[2:]:
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
children.setdefault(s.src[1], []).append(k)

View File

@@ -602,15 +602,15 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
name="bufferize to store")
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
# WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
afters = [u for u in tsink.toposort() if u.op is Ops.AFTER]
kernel_assign: dict[UOp, UOp] = {u.buf_uop:u for u in afters}
assign_rep: dict[UOp, UOp] = {}
for u in tsink.toposort():
if u.op is not Ops.AFTER: continue
kernel_assign[u.buf_uop] = u
for u in afters:
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op not in {Ops.BUFFER, Ops.PARAM} or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if a.src[1] is u.src[1]: continue # same kernel (multi-output custom kernels)
if any(x.op is Ops.AFTER and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))