mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
identity element of binary ops (#6275)
helper for the number reduce acc is inited to (0 for ADD, 1 for MUL and -inf for MAX)
This commit is contained in:
@@ -4,6 +4,7 @@ import functools, itertools, heapq, math, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, graph_rewrite, type_verify, print_uops
|
||||
from tinygrad.ops import identity_element
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
@@ -400,9 +401,7 @@ def do_reduce(root:UOp):
|
||||
ret = root.src[0]
|
||||
if len(reduce_parented):
|
||||
assert root.dtype is not None
|
||||
# TODO: helper to reuse this in 0 size folding
|
||||
const = UOp.const(root.dtype, {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(root.dtype.scalar())}[root.arg])
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype, (root.const(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,))
|
||||
acc_number += 1
|
||||
ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
|
||||
# for MAX, we can just ignore the unparented
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Union, Optional, Any, Tuple, List, get_args
|
||||
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU, identity_element
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
@@ -173,8 +173,7 @@ class LazyBuffer:
|
||||
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
new_shape = self.st.reduce(axis)
|
||||
# TODO: this logic should move to the scheduler
|
||||
if 0 in self.shape and 0 not in new_shape:
|
||||
return self.const({ReduceOps.SUM: 0.0, ReduceOps.PROD: 1.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
|
||||
if 0 in self.shape and 0 not in new_shape: return self.const(identity_element(REDUCE_ALU[op], self.dtype), new_shape)
|
||||
|
||||
# const folding
|
||||
# TODO: fold this for symbolic?
|
||||
|
||||
@@ -35,6 +35,9 @@ UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
|
||||
|
||||
REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
# the order of these UOps controls the order of the toposort
|
||||
class UOps(Enum):
|
||||
# ops that aren't rendered
|
||||
|
||||
Reference in New Issue
Block a user