move usum and uprod to mixin (#15690)

and use them to clean up ops and tensor
This commit is contained in:
chenyu
2026-04-12 11:42:24 -04:00
committed by GitHub
parent e9b2e156b4
commit 0254cfe642
5 changed files with 19 additions and 18 deletions

View File

@@ -1,8 +1,8 @@
import math
import math, functools, operator
from typing import Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType, least_upper_dtype, least_upper_float
from tinygrad.helpers import polyN
from tinygrad.helpers import argfix, polyN
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.creation import CreationMixin
@@ -22,6 +22,9 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
def _binop(self, op: Ops, x: Self | ConstType, reverse: bool) -> Self:
return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def usum(self, *uops) -> Self: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, argfix(*uops), self)
def uprod(self, *uops) -> Self: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, argfix(*uops), self)
def logical_not(self) -> Self:
"""
Computes the logical NOT of the tensor element-wise.

View File

@@ -85,7 +85,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
if x not in ctx.range_map: return None
valid: UOp = UOp.const(dtypes.bool, True).uprod(*[r.get_valid() for r in ctx.range_map[x][0]])
valid: UOp = UOp.const(dtypes.bool, True).uprod([r.get_valid() for r in ctx.range_map[x][0]])
ret = valid.where(x.src[0], UOp.const(x.dtype, 0))
ctx.range_map[ret] = ctx.range_map[x]
return ret
@@ -118,7 +118,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U
for s,src in list(zip(out_shape, urngs.src))[::-1]:
axes_in.append(acc*src)
acc *= s
combined_axes = UOp.const(dtypes.weakint, 0).usum(*axes_in)
combined_axes = UOp.const(dtypes.weakint, 0).usum(axes_in)
axes_out:list[UOp] = []
for s in in_shape[::-1]:
axes_out.append(combined_axes % s)
@@ -211,7 +211,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# we compare the ranges without their valids
if all_all_same or (PCONTIG and all_same(local_rngs)):
# the new valid is the OR of all the children valids
minimum_valid = UOp.const(dtypes.bool, False).usum(*valids)
minimum_valid = UOp.const(dtypes.bool, False).usum(valids)
_out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
else:
_out_rngs.append(rctx.new_range(x.shape[i]))

View File

@@ -1200,9 +1200,9 @@ class Tensor(OpMixin):
consecutive = dims == list(range(dims[0], dims[0] + len(dims)))
if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)):
strides = tuple(prod(ishp[i+1:]) for i in range(len(dims)))
try: linear_idx = functools.reduce(Tensor.add, (t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)))
try: linear_idx = Tensor.usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)])
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
valid = functools.reduce(Tensor.__and__, ((t >= 0) & (t < s) for t, s in zip(tensors, ishp)))
valid = Tensor.uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)])
pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:]
x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)]
return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0)
@@ -1216,7 +1216,7 @@ class Tensor(OpMixin):
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
# reduce masks to 1 mask
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
mask: Tensor = Tensor.uprod(*masks)
# inject 1's for the extra dims added in create masks
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
@@ -1250,7 +1250,7 @@ class Tensor(OpMixin):
per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0))
vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0))
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
return (functools.reduce(lambda a, b: a & b, per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self)
return (Tensor.uprod(*per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self)
def __getitem__(self, indices) -> Tensor:
"""
@@ -1361,7 +1361,7 @@ class Tensor(OpMixin):
tensors = [self, *args]
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
return functools.reduce(Tensor.add, tensors)
return Tensor.usum(*tensors)
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
"""
@@ -1759,7 +1759,7 @@ class Tensor(OpMixin):
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return functools.reduce(lambda a,b:a*b, xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
return Tensor.uprod(*xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
# ***** processing ops *****
@@ -2900,7 +2900,7 @@ class Tensor(OpMixin):
if ns > os:
tmp = tmp.reshape(self.shape[:-1] + (self.shape[-1]//(rate := ns//os), rate))
nones = (None,) * (tmp.ndim - 1)
return functools.reduce(Tensor.add, (tmp.shrink(nones + ((i, i+1),)).cast(new_uint)<<8*i*os for i in range(rate))).squeeze(-1).bitcast(dtype)
return Tensor.usum(*[tmp.shrink(nones + ((i, i+1),)).cast(new_uint)<<8*i*os for i in range(rate)]).squeeze(-1).bitcast(dtype)
return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self

View File

@@ -829,8 +829,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
new_count.subtract(div_fac.split_uop(Ops.MUL))
if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod(new_count.elements(), start=self.const_like(const//div_const))
return None # generic None if we aren't sure
def usum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self)
def uprod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self)
@property
def vmin(self) -> PyConst: return self._min_max[0]
@property

View File

@@ -1,5 +1,5 @@
# all of symbolic lives here now
import math, operator, struct, functools
import math, struct
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
@@ -41,11 +41,11 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None:
# ((base//d)%div)*mul + (base//(d*div))*(div*mul) -> (base//d)*mul
if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST:
exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div
if exact: return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base*mul)
if exact: return (base*mul).usum(*[t for k,t in enumerate(terms) if k not in (i,j)])
# ((base//div)%d)*div + base%div -> base%(div*d)
if mul == 1 and div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV:
if q.src[0].src[0] is base and q.src[0].src[1].op is Ops.CONST and q.src[0].src[1].arg == div:
return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base % (div*d))
return (base % (div*d)).usum(*[t for k,t in enumerate(terms) if k not in (i,j)])
return None
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
@@ -390,7 +390,7 @@ def where_on_load(cond:UOp, buf:UOp, idx:UOp, or_cast:UOp) -> UOp|None:
return c.ranges.keys() <= idx.ranges.keys() and all(u in idx_index for u in c.backward_slice_with_self if u.op is Ops.INDEX)
moved, keep = partition([c for c in where_clauses if c not in in_load], can_move)
if len(keep) == len(where_clauses): return None
idx = buf.index(idx.get_idx().valid(functools.reduce(operator.and_, moved, load_valid)))
idx = buf.index(idx.get_idx().valid(load_valid.uprod(*moved)))
return UOp.const(dtypes.bool, True).uprod(*keep).where(idx.cast(or_cast.dtype) if or_cast.op is Ops.CAST else idx, 0)
# where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer