From 61ca19ff242c3a3396b4763f79965cedad90fa51 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 2 Feb 2026 10:19:05 -0500 Subject: [PATCH] after with empty src is self [pr] (#14496) --- tinygrad/codegen/late/devectorizer.py | 3 +-- tinygrad/uop/ops.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 62c2a8c624..dbf4fd3ae6 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -319,8 +319,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges]) identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar())) acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=ctx.acc_num) - acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \ - acc.index(UOp.const(dtypes.int, 0)).store(identity) + acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 1075290688..6c5ed761be 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -414,10 +414,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs) def store(self, src:UOp|ConstType, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs) - def end(self, *src:UOp): - if len(src) == 0: return self - return UOp(Ops.END, src=(self,)+src) - def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) + def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self + def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src) def contract(self, *rngs:UOp):