mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
linearizer fix from dsp branch (#9641)
* linearizer fix from dsp branch * revert that
This commit is contained in:
@@ -34,23 +34,27 @@ def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]],
|
||||
old_blocks: dict[tuple[UOp, ...], UOp] = {}
|
||||
new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}
|
||||
|
||||
seen_u = set()
|
||||
for u in x.src:
|
||||
if u.op is Ops.BLOCK:
|
||||
# merge sibling blocks. NOTE: blocks must only have one output source
|
||||
assert u.arg.ctx not in old_blocks, "sibling should never have been created"
|
||||
old_blocks[u.arg.ctx] = u
|
||||
if u not in seen_u:
|
||||
# merge sibling blocks. NOTE: blocks must only have one output source
|
||||
assert u.arg.ctx not in old_blocks, "sibling should never have been created"
|
||||
old_blocks[u.arg.ctx] = u
|
||||
elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
|
||||
# if it can go in blocks and all its children are in the block, we add it to the block
|
||||
if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
|
||||
# if it's the same context, we place the UOp in this block and append the parents to its srcs
|
||||
new_srcs.extend(u.src)
|
||||
to_append.append(u)
|
||||
else:
|
||||
# if it's a different context, we create a new block with this UOp
|
||||
new_blocks.setdefault(block_ctx, []).append(u)
|
||||
if u not in seen_u:
|
||||
# if it can go in blocks and all its children are in the block, we add it to the block
|
||||
if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
|
||||
# if it's the same context, we place the UOp in this block and append the parents to its srcs
|
||||
new_srcs.extend(u.src)
|
||||
to_append.append(u)
|
||||
else:
|
||||
# if it's a different context, we create a new block with this UOp
|
||||
new_blocks.setdefault(block_ctx, []).append(u)
|
||||
else:
|
||||
# otherwise, we keep it in the srcs
|
||||
new_srcs.append(u)
|
||||
seen_u.add(u)
|
||||
if len(to_append) == 0 and len(new_blocks) == 0: return None
|
||||
|
||||
for rng,lst in new_blocks.items():
|
||||
|
||||
@@ -170,6 +170,20 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
||||
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
|
||||
return rem//(c//gcd)+quo
|
||||
|
||||
gep_pushing = PatternMatcher([
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** COMMUTATIVE flipping (only for ints) **
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
@@ -230,18 +244,7 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) else None),
|
||||
])
|
||||
])+gep_pushing
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
# ** combine terms (opinionated) **
|
||||
|
||||
@@ -362,7 +362,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
||||
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
|
||||
return ret
|
||||
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
||||
def sink(self, *srcs:UOp, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+srcs, **kwargs)
|
||||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def const_like(self, b:ConstLike):
|
||||
|
||||
Reference in New Issue
Block a user