mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
move getting 0 and min value of a dtype to dtype.py (#5328)
cleanup getting base case for reduce ops [run_process_replay]
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
|
||||
import itertools, math, functools
|
||||
import itertools, functools
|
||||
from collections import defaultdict
|
||||
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
@@ -108,10 +108,8 @@ render_ops: Dict[Type, Callable[..., UOp]] = {
|
||||
|
||||
class Linearizer(Kernel):
|
||||
def get_reduce_acc(self, reduceop:LazyOp):
|
||||
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
|
||||
if reduceop.op is ReduceOps.MAX:
|
||||
if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
|
||||
return -math.inf if dtypes.is_float(reduceop.dtype) else False
|
||||
if reduceop.op is ReduceOps.SUM: return dtypes.as_const(0, reduceop.dtype)
|
||||
if reduceop.op is ReduceOps.MAX: return dtypes.min(reduceop.dtype)
|
||||
|
||||
# NOTE: once images are loaded, we uop them as their base float
|
||||
def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
|
||||
|
||||
@@ -53,6 +53,10 @@ class dtypes:
|
||||
@staticmethod
|
||||
def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
|
||||
@staticmethod
|
||||
def min(dtype:DType):
|
||||
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
|
||||
return -float("inf") if dtypes.is_float(dtype) else False
|
||||
@staticmethod
|
||||
def fields() -> Dict[str, DType]: return DTYPES_DICT
|
||||
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
|
||||
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
from typing import Union, Optional, Any, Tuple, List
|
||||
from tinygrad.dtype import dtypes, DType, ConstType
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
|
||||
@@ -177,9 +176,7 @@ class LazyBuffer:
|
||||
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
|
||||
# TODO: this logic should move to the scheduler
|
||||
if self.size == 0 and 0 not in new_shape:
|
||||
# TODO: move base case const to dtype
|
||||
return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf if self.dtype != dtypes.bool else False}[op], new_shape)
|
||||
if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
|
||||
|
||||
# const folding
|
||||
# TODO: fold this for symbolic?
|
||||
|
||||
Reference in New Issue
Block a user