Rename find_permutes to fix_assign_hazard in tinygrad/schedule/rangeify.py.

The ASSIGN rewrite helper now takes (dest, src, assign) instead of
(a, b, assign), and the UPat variable names of the matching entry in
earliest_rewrites are renamed to agree with it. The helper still returns
None when nothing matches, and otherwise returns the assign with its source
replaced by src.contiguous().

The movement-op filter is now an explicit {Ops.PERMUTE, Ops.FLIP} set rather
than "in GroupOp.Movement and not in {RESHAPE, EXPAND, PAD, SHRINK}".
NOTE(review): the two filters select the same ops only if GroupOp.Movement
is exactly those six ops; confirm against the GroupOp definition.

The new helper evaluates dest.base inside the loop, where the old one
hoisted it into a local (target = a.base) before the loop.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Line breaks were restored to agree with the
hunk line counts; indentation (two-space assumed) and the position of blank
context lines are a best guess. Verify against the original commit before
applying.

diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index e01b97be3f..017fd00d4a 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -27,12 +27,13 @@ pm_mops = PatternMatcher([
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff
 
-def find_permutes(a:UOp, b:UOp, assign:UOp):
-  if not (permutes:=[s for s in b.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)
-                     if s.op in GroupOp.Movement and s.op not in {Ops.RESHAPE, Ops.EXPAND, Ops.PAD, Ops.SHRINK}]): return
-  target = a.base
-  for p in permutes:
-    if any(s is target for s in p.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})): return assign.replace(src=(a, b.contiguous()))
+def fix_assign_hazard(dest:UOp, src:UOp, assign:UOp):
+  # PERMUTE and FLIP reorder indices, causing read/write races when src and dest are the same buffer
+  unsafe = {Ops.PERMUTE, Ops.FLIP}
+  if not (hazards:=[s for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS) if s.op in unsafe]): return
+  for h in hazards:
+    if any(s is dest.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
+      return assign.replace(src=(dest, src.contiguous()))
 
 def split_reduceop(reduce:UOp, x:UOp):
   if prod(reduce.shape) == 0: return None
@@ -116,8 +117,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
    lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
        not (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src))) else None),
 
-  # realize before assign if input permutes the target buffer
-  (UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
+  # make source contiguous if it has hazardous movement ops on the dest buffer
+  (UPat(Ops.ASSIGN, src=(UPat.var("dest"), UPat.var("src")), name="assign"), fix_assign_hazard),
 ])
 
 # *****************