identity element of binary ops (#6275)

Adds a helper that returns the value a reduce accumulator is initialized to: 0 for ADD, 1 for MUL, and dtypes.min(dtype) (e.g. -inf) for MAX.
This commit is contained in:
chenyu
2024-08-24 18:10:19 -04:00
committed by GitHub
parent ee245b48a9
commit 00282afa41
3 changed files with 7 additions and 6 deletions

View File

@@ -4,6 +4,7 @@ import functools, itertools, heapq, math, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, graph_rewrite, type_verify, print_uops
from tinygrad.ops import identity_element
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
if TYPE_CHECKING: from tinygrad.renderer import Renderer
@@ -400,9 +401,7 @@ def do_reduce(root:UOp):
ret = root.src[0]
if len(reduce_parented):
assert root.dtype is not None
# TODO: helper to reuse this in 0 size folding
const = UOp.const(root.dtype, {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(root.dtype.scalar())}[root.arg])
acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
acc = UOp(UOps.DEFINE_ACC, root.dtype, (root.const(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,))
acc_number += 1
ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
# for MAX, we can just ignore the unparented

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU, identity_element
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -173,8 +173,7 @@ class LazyBuffer:
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = self.st.reduce(axis)
# TODO: this logic should move to the scheduler
if 0 in self.shape and 0 not in new_shape:
return self.const({ReduceOps.SUM: 0.0, ReduceOps.PROD: 1.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
if 0 in self.shape and 0 not in new_shape: return self.const(identity_element(REDUCE_ALU[op], self.dtype), new_shape)
# const folding
# TODO: fold this for symbolic?

View File

@@ -35,6 +35,9 @@ UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
# maps each reduce op to its corresponding binary ALU op (SUM -> ADD, PROD -> MUL, MAX -> MAX)
REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:BinaryOps, dt:DType):
  """Return the identity element of binary op `op` as a constant of dtype `dt`.

  The identity is 0 for ADD, 1 for MUL and `dtypes.min(dt)` for MAX; the chosen
  value is passed through `dtypes.as_const` together with `dt`.

  Raises KeyError if `op` is not one of ADD, MUL or MAX.
  """
  # dtypes.min(dt) is evaluated for every call, regardless of which op is requested
  identities = {BinaryOps.ADD: 0, BinaryOps.MUL: 1, BinaryOps.MAX: dtypes.min(dt)}
  return dtypes.as_const(identities[op], dt)
# the order of these UOps controls the order of the toposort
class UOps(Enum):
# ops that aren't rendered