diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index ddf700d6ba..d5b0d90216 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -105,3 +105,9 @@ def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else # HACK: staticmethods are not callable in 3.8 so we have to compare the class DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)} INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()} + +def sum_acc_dtype(dt:DType): + # default acc dtype for sum + if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint) + if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int) + return least_upper_dtype(dt, dtypes.float) \ No newline at end of file diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index 3237fd59ed..24eb9cbe76 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -113,14 +113,14 @@ class MultiLazyBuffer: def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]: return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape)) - def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> MultiLazyBuffer: + def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer: if self.axis is not None and self.axis in axis: # all-reduce on sharded axes - reduced_parts = [(x if r else x.const(0)).r(op, axis, acc_dt) for x,r in zip(self.lbs, self.real)] + reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)] if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None) return MultiLazyBuffer(reduced_parts, None, self.real) # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct - return MultiLazyBuffer([x.r(op, axis, acc_dt) for x in self.lbs], self.axis, self.real) + return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real) # *** movement ops *** diff --git a/tinygrad/function.py b/tinygrad/function.py index 15bbc2800b..462139ecd3 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -2,7 +2,7 @@ import math from typing import Tuple, Optional from tinygrad.helpers import argsort -from tinygrad.dtype import DType +from tinygrad.dtype import dtypes, DType, sum_acc_dtype from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps from tinygrad.tensor import Function from tinygrad.lazy import LazyBuffer @@ -146,22 +146,22 @@ class Where(Function): # ************* reduce ops ************* class Sum(Function): - def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer: + def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: self.input_shape = x.shape - return x.r(ReduceOps.SUM, axis, acc_dtype) + return x.r(ReduceOps.SUM, axis) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape) class Max(Function): - def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer: + def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype) + max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float) div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape) - return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) + return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) # ************* movement ops ************* @@ -171,7 +171,8 @@ class Expand(Function): self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so) return x.expand(shape) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.expanded_axis) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: + return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype) class Reshape(Function): def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index a4c725ef78..80c72ad971 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,7 +1,7 @@ from __future__ import annotations import math from typing import Union, Optional, Any, Tuple, List -from tinygrad.dtype import dtypes, DType, ConstType, least_upper_dtype +from tinygrad.dtype import dtypes, DType, ConstType from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu from tinygrad.shape.symbolic import sint, Variable @@ -161,35 +161,26 @@ class LazyBuffer: # *** reduce ops *** - def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer: + def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer: assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}" axis = tuple(x for x in axis if self.shape[x] != 1) if len(axis) == 0: return self new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) - return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, (axis, acc_dt), (self,)) + return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,)) - def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer: + def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer: new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) # TODO: this logic should move to the scheduler if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape) - # upcast acc_dt here so if reduce is splitted, the intermediate dtype is upcasted - if op is ReduceOps.SUM and acc_dt is None: - acc_dt = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \ - least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \ - least_upper_dtype(self.dtype, dtypes.float) - if acc_dt is not None and acc_dt != self.dtype: - # cast back to float16 or bfloat16 to match torch / jax behavior - return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt) - - # const folding after acc_dt cast to correct output dtype + # const folding if self.is_unrealized_unmasked_const(): return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape) # TODO: can we split symbolic shape if the reduce axis is not symbolic? if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): - return self._reduce_op(op, axis, acc_dt) + return self._reduce_op(op, axis) # if there are few globals, make some reduces into globals by splitting into two kernels # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e25284105d..22b3e32b7b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -6,7 +6,7 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque from collections import defaultdict import numpy as np -from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype +from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, IMAGE, DEBUG, WINO, THREEFRY from tinygrad.helpers import getenv from tinygrad.lazy import LazyBuffer @@ -913,14 +913,16 @@ class Tensor: # ***** reduce ops ***** - def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False, acc_dtype:Optional[DType]=None) -> Tensor: + def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor: axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis)) axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_) shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_) - ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype) + ret = fxn.apply(self, axis=axis_) return ret if keepdim else ret.reshape(shape) - def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): return self._reduce(F.Sum, axis, keepdim, acc_dtype) + def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): + ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim) + return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim) def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))