mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
put assign and store next to each other [pr] (#11306)
This commit is contained in:
@@ -234,6 +234,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def alu(self, arg, *src:UOp):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
@@ -266,7 +267,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
ret, new_axis = self, axis
|
||||
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
|
||||
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
||||
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
||||
|
||||
@@ -461,8 +461,8 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# remove VECTORIZE from SINK/BARRIER
|
||||
(UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
|
||||
if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL, Ops.PTRCAT} else (x,) for x in root.src)), root.arg)
|
||||
if any(x.op in {Ops.SINK, Ops.UNROLL, Ops.PTRCAT} for x in root.src) else None),
|
||||
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
|
||||
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
|
||||
|
||||
Reference in New Issue
Block a user