diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 8fdb08120e..03f6636f9e 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
-import itertools, math, functools
+import itertools, functools
 from collections import defaultdict
 
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
@@ -108,10 +108,8 @@ render_ops: Dict[Type, Callable[..., UOp]] = {
 
 class Linearizer(Kernel):
   def get_reduce_acc(self, reduceop:LazyOp):
-    if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
-    if reduceop.op is ReduceOps.MAX:
-      if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
-      return -math.inf if dtypes.is_float(reduceop.dtype) else False
+    if reduceop.op is ReduceOps.SUM: return dtypes.as_const(0, reduceop.dtype)
+    if reduceop.op is ReduceOps.MAX: return dtypes.min(reduceop.dtype)
 
   # NOTE: once images are loaded, we uop them as their base float
   def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 9c1481f282..bf62a6e5c2 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -53,6 +53,10 @@ class dtypes:
   @staticmethod
   def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
+  def min(dtype:DType):
+    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
+    return -float("inf") if dtypes.is_float(dtype) else False
+  @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
   int8: Final[DType] = DType(1, 1, "char", 'b', 1)
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 5f5217b146..8e13242985 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-import math
 from typing import Union, Optional, Any, Tuple, List
 from tinygrad.dtype import dtypes, DType, ConstType
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
@@ -177,9 +176,7 @@ class LazyBuffer:
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
     new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
     # TODO: this logic should move to the scheduler
-    if self.size == 0 and 0 not in new_shape:
-      # TODO: move base case const to dtype
-      return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf if self.dtype != dtypes.bool else False}[op], new_shape)
+    if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
 
     # const folding
     # TODO: fold this for symbolic?
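
For reference, a minimal sketch of what the new `dtypes.min` helper returns, with values derived from the branches added in `tinygrad/dtype.py` above; the snippet is illustrative commentary, not part of the change:

```python
from tinygrad.dtype import dtypes

# MAX reductions now seed their accumulator with the dtype's minimum value:
assert dtypes.min(dtypes.uint8)   == 0               # unsigned ints bottom out at 0
assert dtypes.min(dtypes.int8)    == -128            # signed ints: -2**(itemsize*8-1)
assert dtypes.min(dtypes.int32)   == -(2**31)
assert dtypes.min(dtypes.float32) == -float("inf")   # floats: negative infinity
assert dtypes.min(dtypes.bool)    == False           # bool: False
```

This moves the MAX identity element into one place instead of the ad-hoc constants previously duplicated in `Linearizer.get_reduce_acc` and `LazyBuffer.r`.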