mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
clean up arange/indexing matchers [pr] (#7427)
* clean up arange/indexing matchers [pr] * syntax for assign
This commit is contained in:
@@ -222,6 +222,8 @@ def no_vectorized_wmma(wmma:UOp):
|
||||
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
|
||||
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
|
||||
|
||||
index_load = UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))).load(name="ld")
|
||||
|
||||
# this is symbolic 2.0
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
# self ASSIGN is just self
|
||||
@@ -260,12 +262,6 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# threefry
|
||||
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
# arange loop folding
|
||||
(UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any(
|
||||
m1:=(UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, name="rng")),
|
||||
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (new ge)
|
||||
(UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any(
|
||||
m1:=(UPat.var("idx") + UPat.any(UPat.cvar("mval") * UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
|
||||
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
|
||||
@@ -273,12 +269,10 @@ sym = symbolic_flat+PatternMatcher([
|
||||
.where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# indexing, with cast or where
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))),),
|
||||
name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))),),
|
||||
name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load,),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0)),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
# GEP/CAST const rules
|
||||
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
# ** self folding **
|
||||
@@ -376,10 +370,10 @@ def do_reduce(ctx:List[int], root:UOp):
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype,
|
||||
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (ctx[0],))
|
||||
ctx[0] += 1
|
||||
ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret)))
|
||||
ret = acc.assign(acc.alu(root.arg, ret))
|
||||
# for MAX, we can just ignore the unparented
|
||||
if root.arg is BinaryOps.ADD:
|
||||
for r in reduce_unparented:ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
def do_contract(con:UOp):
|
||||
|
||||
@@ -321,10 +321,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self
|
||||
assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}"
|
||||
return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
@staticmethod
|
||||
def load(*src:UOp, dtype:DType): return UOp(UOps.LOAD, dtype, src)
|
||||
@staticmethod
|
||||
def store(*src:UOp): return UOp(UOps.STORE, dtypes.void, src)
|
||||
def load(self, *src:UOp, **kwargs): return UOp(UOps.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UOp, **kwargs): return UOp(UOps.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def alu(self, arg, *src:UOp):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
|
||||
@@ -341,6 +339,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def r(self, op, axis): return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
|
||||
def assign(self, x:UOp): return UOp(UOps.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
# *** uop Variable stuff ***
|
||||
|
||||
@@ -566,10 +565,9 @@ class UPat(MathTrait):
|
||||
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))
|
||||
@staticmethod
|
||||
def load(*src:UPat, **kwargs): return UPat(UOps.LOAD, src=src, **kwargs)
|
||||
@staticmethod
|
||||
def store(*src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, src, **kwargs)
|
||||
def load(self, *src:UPat, **kwargs): return UPat(UOps.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat): return UPat(UOps.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UPat.const(self.dtype, cast(ConstType, b))
|
||||
def alu(self, arg, *src:UPat):
|
||||
|
||||
Reference in New Issue
Block a user