diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index ff54ccbe27..c705d745ae 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -234,6 +234,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs) def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs) + def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) def alu(self, arg, *src:UOp): out_dtype = (self, *src)[-1].dtype if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool @@ -266,7 +267,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): ret, new_axis = self, axis ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis)) return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)])) - def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) def contiguous(self): return self.alu(Ops.CONTIGUOUS) def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD) diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 3cbe4f6bb6..a2904760d4 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -461,8 +461,8 @@ sym = symbolic_flat+PatternMatcher([ # remove VECTORIZE from SINK/BARRIER (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)), (UPat(Ops.SINK, name="root"), - lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg) - if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None), + lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL, Ops.PTRCAT} else (x,) for x in root.src)), root.arg) + if any(x.op in {Ops.SINK, Ops.UNROLL, Ops.PTRCAT} for x in root.src) else None), ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()), (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)