diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 712d4df54b..c162c23fc7 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -4,6 +4,7 @@ import functools, itertools, heapq, math, operator
 from collections import defaultdict
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
 from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, graph_rewrite, type_verify, print_uops
+from tinygrad.ops import identity_element
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
 if TYPE_CHECKING: from tinygrad.renderer import Renderer
@@ -400,9 +401,7 @@ def do_reduce(root:UOp):
   ret = root.src[0]
   if len(reduce_parented):
     assert root.dtype is not None
-    # TODO: helper to reuse this in 0 size folding
-    const = UOp.const(root.dtype, {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(root.dtype.scalar())}[root.arg])
-    acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
+    acc = UOp(UOps.DEFINE_ACC, root.dtype, (root.const(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,))
     acc_number += 1
     ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
   # for MAX, we can just ignore the unparented
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 0c99745115..f1b497616e 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Union, Optional, Any, Tuple, List, get_args
 from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
-from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
+from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU, identity_element
 from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
@@ -173,8 +173,7 @@ class LazyBuffer:
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
     new_shape = self.st.reduce(axis)
     # TODO: this logic should move to the scheduler
-    if 0 in self.shape and 0 not in new_shape:
-      return self.const({ReduceOps.SUM: 0.0, ReduceOps.PROD: 1.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
+    if 0 in self.shape and 0 not in new_shape: return self.const(identity_element(REDUCE_ALU[op], self.dtype), new_shape)

     # const folding
     # TODO: fold this for symbolic?
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b04dc99c07..d7d331820f 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -35,6 +35,9 @@ UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}

 REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}

+# https://en.wikipedia.org/wiki/Identity_element
+def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
+
 # the order of these UOps controls the order of the toposort
 class UOps(Enum):
   # ops that aren't rendered
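For reference, this patch deduplicates two hand-written identity tables (the accumulator init in do_reduce and the zero-size fold in LazyBuffer.r) behind the new identity_element helper. A minimal sketch of its behavior, assuming a tinygrad checkout with this patch applied (exact return values come from dtypes.as_const casting the Python literal to the target dtype):

from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, identity_element

# ADD accumulates from 0, MUL from 1, MAX from the smallest value the dtype can hold
assert identity_element(BinaryOps.ADD, dtypes.float32) == 0.0
assert identity_element(BinaryOps.MUL, dtypes.float32) == 1.0
assert identity_element(BinaryOps.MAX, dtypes.float32) == -float("inf")  # dtypes.min of a float dtype
assert identity_element(BinaryOps.MAX, dtypes.int8) == -128              # dtypes.min of int8

Note the helper is keyed on BinaryOps, which is why the lazy.py call site goes through REDUCE_ALU first: that dict maps ReduceOps.SUM/PROD/MAX to the corresponding BinaryOps.ADD/MUL/MAX.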