move getting 0 and min value of a dtype to dtype.py (#5328)

cleanup getting base case for reduce ops
[run_process_replay]
This commit is contained in:
chenyu
2024-07-08 10:51:56 -04:00
committed by GitHub
parent b0c5c58833
commit 7d049fc20c
3 changed files with 8 additions and 9 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
import itertools, math, functools
import itertools, functools
from collections import defaultdict
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
@@ -108,10 +108,8 @@ render_ops: Dict[Type, Callable[..., UOp]] = {
class Linearizer(Kernel):
def get_reduce_acc(self, reduceop:LazyOp):
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
if reduceop.op is ReduceOps.MAX:
if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
return -math.inf if dtypes.is_float(reduceop.dtype) else False
if reduceop.op is ReduceOps.SUM: return dtypes.as_const(0, reduceop.dtype)
if reduceop.op is ReduceOps.MAX: return dtypes.min(reduceop.dtype)
# NOTE: once images are loaded, we uop them as their base float
def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt

View File

@@ -53,6 +53,10 @@ class dtypes:
@staticmethod
def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
@staticmethod
def min(dtype:DType):
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
return -float("inf") if dtypes.is_float(dtype) else False
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
int8: Final[DType] = DType(1, 1, "char", 'b', 1)

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import math
from typing import Union, Optional, Any, Tuple, List
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
@@ -177,9 +176,7 @@ class LazyBuffer:
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
# TODO: this logic should move to the scheduler
if self.size == 0 and 0 not in new_shape:
# TODO: move base case const to dtype
return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf if self.dtype != dtypes.bool else False}[op], new_shape)
if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
# const folding
# TODO: fold this for symbolic?