diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 22e9f91aaa..509efc1ae0 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -222,6 +222,8 @@ def no_vectorized_wmma(wmma:UOp): wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas]) return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex)) +index_load = UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))).load(name="ld") + # this is symbolic 2.0 sym = symbolic_flat+PatternMatcher([ # self ASSIGN is just self @@ -260,12 +262,6 @@ sym = symbolic_flat+PatternMatcher([ # threefry (UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32), # arange loop folding - (UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any( - m1:=(UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, name="rng")), - m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1)) - .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),), - arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - # arange loop folding (new ge) (UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any( m1:=(UPat.var("idx") + UPat.any(UPat.cvar("mval") * UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))), m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1)) @@ -273,12 +269,10 @@ sym = symbolic_flat+PatternMatcher([ .where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), # indexing, with cast or where - (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()* - UPat(UOps.LOAD, src=(UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))),), - name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse), - (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where( - UPat(UOps.LOAD, src=(UPat.var("buf").index(UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE,name="rng"), UPat(UOps.RANGE,name="rng"))),), - name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse), + (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load,), + arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse), + (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0)),), + arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse), # GEP/CAST const rules (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)), # ** self folding ** @@ -376,10 +370,10 @@ def do_reduce(ctx:List[int], root:UOp): acc = UOp(UOps.DEFINE_ACC, root.dtype, (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (ctx[0],)) ctx[0] += 1 - ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret))) + ret = acc.assign(acc.alu(root.arg, ret)) # for MAX, we can just ignore the unparented if root.arg is BinaryOps.ADD: - for r in reduce_unparented:ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) + for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret def do_contract(con:UOp): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index dc911316da..10dbe80007 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -321,10 +321,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}" return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) - @staticmethod - def load(*src:UOp, dtype:DType): return UOp(UOps.LOAD, dtype, src) - @staticmethod - def store(*src:UOp): return UOp(UOps.STORE, dtypes.void, src) + def load(self, *src:UOp, **kwargs): return UOp(UOps.LOAD, src=(self,)+src, **kwargs) + def store(self, *src:UOp, **kwargs): return UOp(UOps.STORE, dtypes.void, (self,)+src, **kwargs) def alu(self, arg, *src:UOp): out_dtype = (self, *src)[-1].dtype if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None: @@ -341,6 +339,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx) def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op) def r(self, op, axis): return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis)) + def assign(self, x:UOp): return UOp(UOps.ASSIGN, self.dtype, (self,x)) # *** uop Variable stuff *** @@ -566,10 +565,9 @@ class UPat(MathTrait): def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,)) def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,)) def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,)) - @staticmethod - def load(*src:UPat, **kwargs): return UPat(UOps.LOAD, src=src, **kwargs) - @staticmethod - def store(*src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, src, **kwargs) + def load(self, *src:UPat, **kwargs): return UPat(UOps.LOAD, src=(self,)+src, **kwargs) + def store(self, *src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, (self,)+src, **kwargs) + def assign(self, x:UPat): return UPat(UOps.ASSIGN, self.dtype, (self,x)) def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UPat.const(self.dtype, cast(ConstType, b)) def alu(self, arg, *src:UPat):