put assign and store next to each other [pr] (#11306)

This commit is contained in:
George Hotz
2025-07-21 11:07:35 -07:00
committed by GitHub
parent de2df92551
commit 41de76a7fd
2 changed files with 3 additions and 3 deletions

View File

@@ -234,6 +234,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
@@ -266,7 +267,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret, new_axis = self, axis
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)

View File

@@ -461,8 +461,8 @@ sym = symbolic_flat+PatternMatcher([
# remove VECTORIZE from SINK/BARRIER
(UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL, Ops.PTRCAT} else (x,) for x in root.src)), root.arg)
if any(x.op in {Ops.SINK, Ops.UNROLL, Ops.PTRCAT} for x in root.src) else None),
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)