clean up arange/indexing matchers [pr] (#7427)

* clean up arange/indexing matchers [pr]

* syntax for assign
This commit is contained in:
George Hotz
2024-10-31 11:12:44 +07:00
committed by GitHub
parent e446e95974
commit fe2bc4c613
2 changed files with 14 additions and 22 deletions

View File

@@ -222,6 +222,8 @@ def no_vectorized_wmma(wmma:UOp):
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
# Shared LOAD pattern (named "ld") reused by the REDUCE indexing rules in `sym`.
# Its first source is "buf" indexed by either  add + mul*RANGE  or a bare RANGE; in both alternatives the RANGE is captured as "rng".
index_load = UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))).load(name="ld")
# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
@@ -260,12 +262,6 @@ sym = symbolic_flat+PatternMatcher([
# threefry
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
# arange loop folding
(UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any(
m1:=(UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, name="rng")),
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (new ge)
(UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any(
m1:=(UPat.var("idx") + UPat.any(UPat.cvar("mval") * UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
@@ -273,12 +269,10 @@ sym = symbolic_flat+PatternMatcher([
.where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# indexing, with cast or where
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
UPat(UOps.LOAD, src=(UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))),),
name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(
UPat(UOps.LOAD, src=(UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))),),
name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load,),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
# GEP/CAST const rules
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
# ** self folding **
@@ -376,10 +370,10 @@ def do_reduce(ctx:List[int], root:UOp):
acc = UOp(UOps.DEFINE_ACC, root.dtype,
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (ctx[0],))
ctx[0] += 1
ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret)))
ret = acc.assign(acc.alu(root.arg, ret))
# for MAX, we can just ignore the unparented
if root.arg is BinaryOps.ADD:
for r in reduce_unparented:ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
def do_contract(con:UOp):

View File

@@ -321,10 +321,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self
assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}"
return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
@staticmethod
def load(*src:UOp, dtype:DType): return UOp(UOps.LOAD, dtype, src)
@staticmethod
def store(*src:UOp): return UOp(UOps.STORE, dtypes.void, src)
def load(self, *src:UOp, **kwargs):
  """Build a LOAD UOp with this UOp as its first source.

  Any extra `src` UOps follow `self` in the source tuple; `kwargs` are forwarded to the UOp constructor unchanged.
  """
  sources = (self, *src)
  return UOp(UOps.LOAD, src=sources, **kwargs)
def store(self, *src:UOp, **kwargs):
  """Build a STORE UOp (dtype `dtypes.void`) with this UOp as its first source.

  Any extra `src` UOps follow `self` in the source tuple; `kwargs` are forwarded to the UOp constructor unchanged.
  """
  sources = (self, *src)
  return UOp(UOps.STORE, dtypes.void, sources, **kwargs)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
@@ -341,6 +339,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def reduce(self, op:BinaryOps, *rng:UOp):
  """Build a REDUCE UOp that carries this UOp's dtype.

  The sources are `self` followed by every UOp in `rng`; `op` is passed as the last constructor argument.
  """
  reduce_srcs = (self, *rng)
  return UOp(UOps.REDUCE, self.dtype, reduce_srcs, op)
def r(self, op, axis):
  """Build a REDUCE_AXIS UOp over `self` with this UOp's dtype.

  A member of `ReduceOps` is first translated through `REDUCE_ALU`; any other `op` is used as given.
  The last constructor argument is the pair `(op, axis)`.
  """
  if op in ReduceOps: alu_op = REDUCE_ALU[op]
  else: alu_op = op
  return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (alu_op, axis))
def assign(self, x:UOp):
  """Build an ASSIGN UOp with this UOp's dtype and the sources `(self, x)`."""
  target_and_value = (self, x)
  return UOp(UOps.ASSIGN, self.dtype, target_and_value)
# *** uop Variable stuff ***
@@ -566,10 +565,9 @@ class UPat(MathTrait):
def cast(self, dtype=None):
  """Return a CAST pattern whose only source is this pattern; `dtype` (default None) is handed to UPat as-is."""
  return UPat(UOps.CAST, dtype=dtype, src=(self,))
def bitcast(self, dtype=None):
  """Return a BITCAST pattern whose only source is this pattern; `dtype` (default None) is handed to UPat as-is."""
  return UPat(UOps.BITCAST, dtype=dtype, src=(self,))
def gep(self, i:int):
  """Return a GEP pattern over this pattern with dtype None and the one-element tuple `(i,)` as its arg."""
  element = (i,)
  return UPat(UOps.GEP, dtype=None, src=(self,), arg=element)
@staticmethod
def load(*src:UPat, **kwargs): return UPat(UOps.LOAD, src=src, **kwargs)
@staticmethod
def store(*src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, src, **kwargs)
def load(self, *src:UPat, **kwargs):
  """Return a LOAD pattern with this pattern as its first source.

  Any extra `src` patterns follow `self` in the source tuple; `kwargs` (e.g. `name=`) are forwarded to UPat unchanged.
  """
  sources = (self, *src)
  return UPat(UOps.LOAD, src=sources, **kwargs)
def store(self, *src:UPat, **kwargs):
  """Return a STORE pattern (dtype `dtypes.void`) with this pattern as its first source.

  Any extra `src` patterns follow `self` in the source tuple; `kwargs` are forwarded to UPat unchanged.
  """
  sources = (self, *src)
  return UPat(UOps.STORE, dtypes.void, sources, **kwargs)
def assign(self, x:UPat):
  """Return an ASSIGN pattern that reuses this pattern's dtype and has the sources `(self, x)`."""
  target_and_value = (self, x)
  return UPat(UOps.ASSIGN, self.dtype, target_and_value)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]):
  """Return `UPat.const` built from this pattern's dtype and the value `b`."""
  # `cast` resolves to the module-level name here (method bodies do not see class-scope names such as UPat.cast);
  # presumably typing.cast, i.e. a type-checker hint only -- confirm against the file's imports.
  value = cast(ConstType, b)
  return UPat.const(self.dtype, value)
def alu(self, arg, *src:UPat):