linearizer fix from dsp branch (#9641)

* linearizer fix from dsp branch

* revert that
This commit is contained in:
George Hotz
2025-03-31 14:26:39 +08:00
committed by GitHub
parent ec405b919f
commit e4c545b396
3 changed files with 31 additions and 24 deletions

View File

@@ -34,23 +34,27 @@ def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]],
old_blocks: dict[tuple[UOp, ...], UOp] = {}
new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}
seen_u = set()
for u in x.src:
if u.op is Ops.BLOCK:
# merge sibling blocks. NOTE: blocks must only have one output source
assert u.arg.ctx not in old_blocks, "sibling should never have been created"
old_blocks[u.arg.ctx] = u
if u not in seen_u:
# merge sibling blocks. NOTE: blocks must only have one output source
assert u.arg.ctx not in old_blocks, "sibling should never have been created"
old_blocks[u.arg.ctx] = u
elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
# if it can go in blocks and all its children are in the block, we add it to the block
if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
# if it's the same context, we place the UOp in this block and append the parents to its srcs
new_srcs.extend(u.src)
to_append.append(u)
else:
# if it's a different context, we create a new block with this UOp
new_blocks.setdefault(block_ctx, []).append(u)
if u not in seen_u:
# if it can go in blocks and all its children are in the block, we add it to the block
if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
# if it's the same context, we place the UOp in this block and append the parents to its srcs
new_srcs.extend(u.src)
to_append.append(u)
else:
# if it's a different context, we create a new block with this UOp
new_blocks.setdefault(block_ctx, []).append(u)
else:
# otherwise, we keep it in the srcs
new_srcs.append(u)
seen_u.add(u)
if len(to_append) == 0 and len(new_blocks) == 0: return None
for rng,lst in new_blocks.items():

View File

@@ -170,6 +170,20 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
return rem//(c//gcd)+quo
gep_pushing = PatternMatcher([
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) else None),
])
symbolic = symbolic_simple+PatternMatcher([
# ** COMMUTATIVE flipping (only for ints) **
(UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
@@ -230,18 +244,7 @@ symbolic = symbolic_simple+PatternMatcher([
# ** mod **
# mod folding
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) else None),
])
])+gep_pushing
symbolic_flat = symbolic+PatternMatcher([
# ** combine terms (opinionated) **

View File

@@ -362,7 +362,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
return ret
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
def sink(self, *srcs:UOp, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+srcs, **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike):