after with empty src is self [pr] (#14496)

This commit is contained in:
chenyu
2026-02-02 10:19:05 -05:00
committed by GitHub
parent 6e958dbfd4
commit 61ca19ff24
2 changed files with 3 additions and 6 deletions

View File

@@ -319,8 +319,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges])
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=ctx.acc_num)
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
acc.index(UOp.const(dtypes.int, 0)).store(identity)
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity)
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)

View File

@@ -414,10 +414,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, src:UOp|ConstType, **kwargs):
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
def end(self, *src:UOp):
if len(src) == 0: return self
return UOp(Ops.END, src=(self,)+src)
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def contract(self, *rngs:UOp):