find_permutes -> fix_assign_hazard [pr] (#14354)

some noop tweaks and comment updates
This commit is contained in:
chenyu
2026-01-26 14:05:19 -05:00
committed by GitHub
parent e152f1b0f5
commit 145df879c1

View File

@@ -27,12 +27,13 @@ pm_mops = PatternMatcher([
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
def find_permutes(a:UOp, b:UOp, assign:UOp):
if not (permutes:=[s for s in b.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)
if s.op in GroupOp.Movement and s.op not in {Ops.RESHAPE, Ops.EXPAND, Ops.PAD, Ops.SHRINK}]): return
target = a.base
for p in permutes:
if any(s is target for s in p.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})): return assign.replace(src=(a, b.contiguous()))
def fix_assign_hazard(dest:UOp, src:UOp, assign:UOp):
# PERMUTE and FLIP reorder indices, causing read/write races when src and dest are the same buffer
unsafe = {Ops.PERMUTE, Ops.FLIP}
if not (hazards:=[s for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS) if s.op in unsafe]): return
for h in hazards:
if any(s is dest.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
return assign.replace(src=(dest, src.contiguous()))
def split_reduceop(reduce:UOp, x:UOp):
if prod(reduce.shape) == 0: return None
@@ -116,8 +117,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
not (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src))) else None),
# realize before assign if input permutes the target buffer
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("dest"), UPat.var("src")), name="assign"), fix_assign_hazard),
])
# *****************