mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
find_permutes -> fix_assign_hazard [pr] (#14354)
some noop tweaks and comment updates
This commit is contained in:
@@ -27,12 +27,13 @@ pm_mops = PatternMatcher([
|
||||
# *****************
|
||||
# 0. do some cleanup rewrites, mostly copied from the old stuff
|
||||
|
||||
def find_permutes(a:UOp, b:UOp, assign:UOp):
|
||||
if not (permutes:=[s for s in b.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)
|
||||
if s.op in GroupOp.Movement and s.op not in {Ops.RESHAPE, Ops.EXPAND, Ops.PAD, Ops.SHRINK}]): return
|
||||
target = a.base
|
||||
for p in permutes:
|
||||
if any(s is target for s in p.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})): return assign.replace(src=(a, b.contiguous()))
|
||||
def fix_assign_hazard(dest:UOp, src:UOp, assign:UOp):
|
||||
# PERMUTE and FLIP reorder indices, causing read/write races when src and dest are the same buffer
|
||||
unsafe = {Ops.PERMUTE, Ops.FLIP}
|
||||
if not (hazards:=[s for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS) if s.op in unsafe]): return
|
||||
for h in hazards:
|
||||
if any(s is dest.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
|
||||
return assign.replace(src=(dest, src.contiguous()))
|
||||
|
||||
def split_reduceop(reduce:UOp, x:UOp):
|
||||
if prod(reduce.shape) == 0: return None
|
||||
@@ -116,8 +117,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
|
||||
not (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src))) else None),
|
||||
|
||||
# realize before assign if input permutes the target buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
|
||||
# make source contiguous if it has hazardous movement ops on the dest buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("dest"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
])
|
||||
|
||||
# *****************
|
||||
|
||||
Reference in New Issue
Block a user