diff --git a/tinygrad/mixin/elementwise.py b/tinygrad/mixin/elementwise.py index 8a4129c495..2df166059f 100644 --- a/tinygrad/mixin/elementwise.py +++ b/tinygrad/mixin/elementwise.py @@ -1,8 +1,8 @@ -import math +import math, functools, operator from typing import Self from tinygrad.uop import Ops from tinygrad.dtype import dtypes, ConstType, least_upper_dtype, least_upper_float -from tinygrad.helpers import polyN +from tinygrad.helpers import argfix, polyN from tinygrad.mixin.dtype import DTypeMixin from tinygrad.mixin.creation import CreationMixin @@ -22,6 +22,9 @@ class ElementwiseMixin(DTypeMixin, CreationMixin): def _binop(self, op: Ops, x: Self | ConstType, reverse: bool) -> Self: return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x)) + def usum(self, *uops) -> Self: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, argfix(*uops), self) + def uprod(self, *uops) -> Self: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, argfix(*uops), self) + def logical_not(self) -> Self: """ Computes the logical NOT of the tensor element-wise. diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 4cb0b65231..7d5635de37 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -85,7 +85,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp): if x not in ctx.range_map: return None - valid: UOp = UOp.const(dtypes.bool, True).uprod(*[r.get_valid() for r in ctx.range_map[x][0]]) + valid: UOp = UOp.const(dtypes.bool, True).uprod([r.get_valid() for r in ctx.range_map[x][0]]) ret = valid.where(x.src[0], UOp.const(x.dtype, 0)) ctx.range_map[ret] = ctx.range_map[x] return ret @@ -118,7 +118,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U for s,src in list(zip(out_shape, urngs.src))[::-1]: axes_in.append(acc*src) acc *= s - combined_axes = UOp.const(dtypes.weakint, 0).usum(*axes_in) + combined_axes = UOp.const(dtypes.weakint, 0).usum(axes_in) axes_out:list[UOp] = [] for s in in_shape[::-1]: axes_out.append(combined_axes % s) @@ -211,7 +211,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: # we compare the ranges without their valids if all_all_same or (PCONTIG and all_same(local_rngs)): # the new valid is the OR of all the children valids - minimum_valid = UOp.const(dtypes.bool, False).usum(*valids) + minimum_valid = UOp.const(dtypes.bool, False).usum(valids) _out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid")) else: _out_rngs.append(rctx.new_range(x.shape[i])) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3123cf95d9..f35716aa73 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1200,9 +1200,9 @@ class Tensor(OpMixin): consecutive = dims == list(range(dims[0], dims[0] + len(dims))) if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)): strides = tuple(prod(ishp[i+1:]) for i in range(len(dims))) - try: linear_idx = functools.reduce(Tensor.add, (t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides))) + try: linear_idx = Tensor.usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)]) except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err - valid = functools.reduce(Tensor.__and__, ((t >= 0) & (t < s) for t, s in zip(tensors, ishp))) + valid = Tensor.uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)]) pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:] x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)] return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0) @@ -1216,7 +1216,7 @@ class Tensor(OpMixin): masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim))) # reduce masks to 1 mask - mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks) + mask: Tensor = Tensor.uprod(*masks) # inject 1's for the extra dims added in create masks reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:] @@ -1250,7 +1250,7 @@ class Tensor(OpMixin): per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0)) vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0)) vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops))) - return (functools.reduce(lambda a, b: a & b, per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self) + return (Tensor.uprod(*per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self) def __getitem__(self, indices) -> Tensor: """ @@ -1361,7 +1361,7 @@ class Tensor(OpMixin): tensors = [self, *args] dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0)) for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)]) - return functools.reduce(Tensor.add, tensors) + return Tensor.usum(*tensors) def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: """ @@ -1759,7 +1759,7 @@ class Tensor(OpMixin): # align all tensors to alphabet, multiply, sum non-output, permute to output order xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x for s, x in zip(inputs, xs)] - return functools.reduce(lambda a,b:a*b, xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs)))) + return Tensor.uprod(*xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs)))) # ***** processing ops ***** @@ -2900,7 +2900,7 @@ class Tensor(OpMixin): if ns > os: tmp = tmp.reshape(self.shape[:-1] + (self.shape[-1]//(rate := ns//os), rate)) nones = (None,) * (tmp.ndim - 1) - return functools.reduce(Tensor.add, (tmp.shrink(nones + ((i, i+1),)).cast(new_uint)<<8*i*os for i in range(rate))).squeeze(-1).bitcast(dtype) + return Tensor.usum(*[tmp.shrink(nones + ((i, i+1),)).cast(new_uint)<<8*i*os for i in range(rate)]).squeeze(-1).bitcast(dtype) return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype) return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index eac45c59f5..f58d25a198 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -829,8 +829,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): new_count.subtract(div_fac.split_uop(Ops.MUL)) if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod(new_count.elements(), start=self.const_like(const//div_const)) return None # generic None if we aren't sure - def usum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self) - def uprod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self) @property def vmin(self) -> PyConst: return self._min_max[0] @property diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index d6ada2c912..c276b424b5 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -1,5 +1,5 @@ # all of symbolic lives here now -import math, operator, struct, functools +import math, struct from collections import defaultdict from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid @@ -41,11 +41,11 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None: # ((base//d)%div)*mul + (base//(d*div))*(div*mul) -> (base//d)*mul if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST: exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div - if exact: return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base*mul) + if exact: return (base*mul).usum(*[t for k,t in enumerate(terms) if k not in (i,j)]) # ((base//div)%d)*div + base%div -> base%(div*d) if mul == 1 and div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV: if q.src[0].src[0] is base and q.src[0].src[1].op is Ops.CONST and q.src[0].src[1].arg == div: - return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base % (div*d)) + return (base % (div*d)).usum(*[t for k,t in enumerate(terms) if k not in (i,j)]) return None # this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0 @@ -390,7 +390,7 @@ def where_on_load(cond:UOp, buf:UOp, idx:UOp, or_cast:UOp) -> UOp|None: return c.ranges.keys() <= idx.ranges.keys() and all(u in idx_index for u in c.backward_slice_with_self if u.op is Ops.INDEX) moved, keep = partition([c for c in where_clauses if c not in in_load], can_move) if len(keep) == len(where_clauses): return None - idx = buf.index(idx.get_idx().valid(functools.reduce(operator.and_, moved, load_valid))) + idx = buf.index(idx.get_idx().valid(load_valid.uprod(*moved))) return UOp.const(dtypes.bool, True).uprod(*keep).where(idx.cast(or_cast.dtype) if or_cast.op is Ops.CAST else idx, 0) # where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer