move acc_dt out of lazy (#4382)

move the logic to tensor.py for forward, and function.py for two places in backward (expand and max)
This commit is contained in:
chenyu
2024-05-06 10:41:25 -04:00
committed by GitHub
parent 113c2f00b9
commit 292ce64ad7
5 changed files with 29 additions and 29 deletions

View File

@@ -105,3 +105,9 @@ def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
def sum_acc_dtype(dt:DType):
# default acc dtype for sum
if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
return least_upper_dtype(dt, dtypes.float)

View File

@@ -113,14 +113,14 @@ class MultiLazyBuffer:
def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> MultiLazyBuffer:
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
if self.axis is not None and self.axis in axis:
# all-reduce on sharded axes
reduced_parts = [(x if r else x.const(0)).r(op, axis, acc_dt) for x,r in zip(self.lbs, self.real)]
reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
return MultiLazyBuffer(reduced_parts, None, self.real)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return MultiLazyBuffer([x.r(op, axis, acc_dt) for x in self.lbs], self.axis, self.real)
return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
# *** movement ops ***

View File

@@ -2,7 +2,7 @@
import math
from typing import Tuple, Optional
from tinygrad.helpers import argsort
from tinygrad.dtype import DType
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
@@ -146,22 +146,22 @@ class Where(Function):
# ************* reduce ops *************
class Sum(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.r(ReduceOps.SUM, axis, acc_dtype)
return x.r(ReduceOps.SUM, axis)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
class Max(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype)
max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
# ************* movement ops *************
@@ -171,7 +171,8 @@ class Expand(Function):
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
return x.expand(shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.expanded_axis)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
class Reshape(Function):
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import math
from typing import Union, Optional, Any, Tuple, List
from tinygrad.dtype import dtypes, DType, ConstType, least_upper_dtype
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
from tinygrad.shape.symbolic import sint, Variable
@@ -161,35 +161,26 @@ class LazyBuffer:
# *** reduce ops ***
def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer:
def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
axis = tuple(x for x in axis if self.shape[x] != 1)
if len(axis) == 0: return self
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, (axis, acc_dt), (self,))
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer:
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
# TODO: this logic should move to the scheduler
if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
# upcast acc_dt here so if reduce is splitted, the intermediate dtype is upcasted
if op is ReduceOps.SUM and acc_dt is None:
acc_dt = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
least_upper_dtype(self.dtype, dtypes.float)
if acc_dt is not None and acc_dt != self.dtype:
# cast back to float16 or bfloat16 to match torch / jax behavior
return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt)
# const folding after acc_dt cast to correct output dtype
# const folding
if self.is_unrealized_unmasked_const():
return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
return self._reduce_op(op, axis, acc_dt)
return self._reduce_op(op, axis)
# if there are few globals, make some reduces into globals by splitting into two kernels
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm

View File

@@ -6,7 +6,7 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque
from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype
from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.helpers import getenv
from tinygrad.lazy import LazyBuffer
@@ -913,14 +913,16 @@ class Tensor:
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False, acc_dtype:Optional[DType]=None) -> Tensor:
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_)
shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype)
ret = fxn.apply(self, axis=axis_)
return ret if keepdim else ret.reshape(shape)
def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): return self._reduce(F.Sum, axis, keepdim, acc_dtype)
def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))