from __future__ import annotations from typing import Final, ClassVar, Callable, Literal import math, struct, ctypes, functools from dataclasses import dataclass, fields from tinygrad.helpers import getenv, prod from enum import Enum, auto class InvalidTypeMetaClass(type): instance:None|InvalidType = None def __call__(cls): if (ret:=InvalidTypeMetaClass.instance) is not None: return ret InvalidTypeMetaClass.instance = ret = super().__call__() return ret class InvalidType(metaclass=InvalidTypeMetaClass): def __eq__(self, other): return self is other def __hash__(self): return id(self) def __repr__(self): return "Invalid" def __reduce__(self): return (InvalidType, ()) # Return the global Invalid instance Invalid = InvalidType() ConstType = float|int|bool FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd'] # all DTypes should only be created once class DTypeMetaClass(type): dcache: dict[tuple, DType] = {} def __call__(cls, *args, **kwargs): if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret DTypeMetaClass.dcache[args] = ret = super().__call__(*args) return ret class AddrSpace(Enum): def __repr__(self): return str(self) GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702 @dataclass(frozen=True, eq=False) class DType(metaclass=DTypeMetaClass): priority: int # this determines when things get upcasted itemsize: int name: str fmt: FmtStr|None count: int _scalar: DType|None @staticmethod def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None): return DType(priority, itemsize, name, fmt, 1, None) def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self)) def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count != 1 else "") def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count) @property def base(self): return self @property def vcount(self): return self.count @functools.cache # pylint: disable=method-cache-max-size-none def vec(self, sz:int) -> DType: assert self.count == 1, f"can't vectorize {self} with size {sz}" if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self) def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size) def scalar(self) -> DType: return self._scalar if self._scalar is not None else self def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes") @property def min(self): return dtypes.min(self) @property def max(self): return dtypes.max(self) @dataclass(frozen=True, eq=False) class PtrDType(DType): _base: DType addrspace: AddrSpace v: int size: int = -1 # -1 is unlimited size @property def base(self): return self._base @functools.cache # pylint: disable=method-cache-max-size-none def vec(self, sz:int) -> DType: assert self.v == 1, f"can't vectorize ptr {self} with size {sz}" if sz == 1: return self # sz=1 is a scalar if isinstance(self, ImageDType): return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape) return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size) def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: raise RuntimeError("can't make a pointer from a pointer") def nbytes(self) -> int: if self.size == -1: raise RuntimeError("can't get nbytes of a pointer with unlimited size") return self.size*self.itemsize @property def vcount(self): return self.v def __repr__(self): return f"{self.base.__repr__()}.ptr({self.size}{', '+str(self.addrspace) if self.addrspace != AddrSpace.GLOBAL else ''})" + \ (f'.vec({self.v})' if self.v != 1 else '') @dataclass(frozen=True, eq=False) class ImageDType(PtrDType): shape: tuple[int, ...] = () # shape of the Image def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: assert addrspace == AddrSpace.GLOBAL, "images can't be local" return self def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '') class dtypes: @staticmethod @functools.cache def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType) @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool @functools.cache def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints + (dtypes.index,) @staticmethod @functools.cache def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints @staticmethod def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool @staticmethod def from_py(x) -> DType: if x.__class__ is float: return dtypes.default_float if x.__class__ is int: return dtypes.default_int if x.__class__ is bool: return dtypes.bool # put this in the last is faster because there are more items than lists/tuples to check if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}") @staticmethod def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType): if isinstance(val, tuple): assert len(val) == dtype.count, f"mismatch {val} {dtype}" return tuple(dtypes.as_const(x, dtype) for x in val) if isinstance(val, InvalidType): return val return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val) @staticmethod @functools.cache def min(dtype:DType): if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().itemsize*8-1) return -float("inf") if dtypes.is_float(dtype) else False @staticmethod @functools.cache def max(dtype:DType): if dtypes.is_int(dtype): return 2**(dtype.scalar().itemsize*8)-1+dtypes.min(dtype) return float("inf") if dtypes.is_float(dtype) else True @staticmethod def finfo(dtype:DType) -> tuple[int, int]: """(exponent, mantissa)""" if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type") return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52), dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3: (4, 3)}[dtype] @staticmethod def fields() -> dict[str, DType]: return DTYPES_DICT void: Final[DType] = DType.new(-1, 0, "void", None) index: Final[DType] = DType.new(-1,100, "index", None) bool: Final[DType] = DType.new(0, 1, "bool", '?') int8: Final[DType] = DType.new(1, 1, "signed char", 'b') uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B') int16: Final[DType] = DType.new(3, 2, "short", 'h') uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H') int32: Final[DType] = DType.new(5, 4, "int", 'i') uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I') int64: Final[DType] = DType.new(7, 8, "long", 'q') uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q') fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None) fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None) float16: Final[DType] = DType.new(11, 2, "half", 'e') # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16 bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None) float32: Final[DType] = DType.new(13, 4, "float", 'f') float64: Final[DType] = DType.new(14, 8, "double", 'd') # dtype aliases half = float16; float = float32; double = float64 # noqa: E702 uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702 char = int8; short = int16; int = int32; long = int64 # noqa: E702 # NOTE: these are image dtypes @staticmethod def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp) @staticmethod def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp) default_float: ClassVar[DType] = float32 default_int: ClassVar[DType] = int32 fp8s = (fp8e4m3, fp8e5m2) floats = fp8s + (float16, bfloat16, float32, float64) uints = (uint8, uint16, uint32, uint64) sints = (int8, int16, int32, int64) ints = uints + sints all = floats + ints + (bool, index) # noqa: A003 if (env_default_float := getenv("DEFAULT_FLOAT", "")): dtypes.default_float = getattr(dtypes, env_default_float.lower()) assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype" DTypeLike = str|DType def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower()) # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], } @functools.cache def _get_recursive_parents(dtype:DType) -> set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} @functools.cache def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d.scalar()) for d in ds])) \ if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float) DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index"))} INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"} @functools.cache def can_safe_cast(dt0:DType, dt1:DType) -> bool: # return if dt1 preserves value of dt0 # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html if dt0 == dt1 or dt0 == dtypes.bool: return True match dt1: case dtypes.index: return dt0 in dtypes.ints case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8) case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8) case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8) case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8) case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8) case dtypes.int32: return dt0 in (dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8) case dtypes.int16: return dt0 in (dtypes.uint8, dtypes.int8) case _: return False def sum_acc_dtype(dt:DType): # default acc dtype for sum if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint) if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int) return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32"))) def float_to_fp16(x): try: return struct.unpack('e', struct.pack('e', float(x)))[0] except OverflowError: return math.copysign(math.inf, x) def float_to_bf16(x): if not math.isfinite(x): return x u = struct.unpack('I', struct.pack('f', x))[0] u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000 return struct.unpack('f', struct.pack('I', u))[0] # fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp def float_to_fp8(x: float, dtype: DType) -> int: assert dtype in dtypes.fp8s, "Only for fp8s" # e4m3 don't support inf, return 0x7f(+NaN) and 0xff(-NaN) to match jax # NaN is unordered, can't compare with zero, use math.copysign to get sign if dtype == dtypes.fp8e4m3 and not math.isfinite(x): return 0x7f if math.copysign(1, x) > 0 else 0xff if dtype == dtypes.fp8e5m2 and math.isinf(x): return 0x7c if math.copysign(1, x) > 0 else 0xfc config = { dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000, "OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F}, dtypes.fp8e5m2: {"EXP_BIAS": 15, "SIGNIFICAND_BITS": 3, "MANTISSA_MASK": 0x3, "MINDENORM_O2": 0x3EE0000000000000, "OVERFLOW_THRESHOLD": 0x40EE000000000000 - 1, "MAXNORM": 0x7B, "MINNORM": 0x3F10000000000000, "INF_VALUE": 0x7E} }[dtype] xbits, = struct.unpack('Q', struct.pack('d', x)) FP8_DP_HALF_ULP = 1 << (53 - config["SIGNIFICAND_BITS"] - 1) sign = ((xbits >> 63) & 1) << 7 exp = (((xbits >> 52) & 0x7FF) - 1023 + config["EXP_BIAS"]) mantissa = (xbits >> (53 - config["SIGNIFICAND_BITS"])) & config["MANTISSA_MASK"] absx = xbits & 0x7FFFFFFFFFFFFFFF if absx <= config["MINDENORM_O2"]: res = 0 elif absx > 0x7FF0000000000000: res = 0x7F if dtype == dtypes.fp8e4m3 else 0x7E | mantissa elif absx > config["OVERFLOW_THRESHOLD"]: res = config["MAXNORM"] elif absx >= config["MINNORM"]: res = ((exp << (config["SIGNIFICAND_BITS"] - 1)) | mantissa) round_bits = xbits & ((FP8_DP_HALF_ULP << 1) - 1) if (round_bits > FP8_DP_HALF_ULP) or (round_bits == FP8_DP_HALF_ULP and (mantissa & 1)): res = res + 1 else: shift = 1 - exp mantissa |= 1 << (config["SIGNIFICAND_BITS"] - 1) res = (mantissa >> shift) round_bits = (xbits | (1 << (53 - 1))) & ((FP8_DP_HALF_ULP << (shift + 1)) - 1) if (round_bits > (FP8_DP_HALF_ULP << shift)) or (round_bits == (FP8_DP_HALF_ULP << shift) and (res & 1)): res = res + 1 res |= sign return int(res) def fp8_to_float(x: int, dtype: DType) -> float: assert dtype in dtypes.fp8s, "Only for fp8s" ur = x << 8 if dtype == dtypes.fp8e5m2 and (ur & 0x7FFF) > 0x7C00: ur = 0x7FFF elif dtype == dtypes.fp8e4m3: sign = ur & 0x8000 exponent = ((ur & 0x7800) >> 1) + 0x2000 mantissa = (ur & 0x0700) >> 1 absx = x & 0x7F if absx == 0x7F: ur = 0x7FFF elif exponent == 0x2000: if mantissa != 0: mantissa <<= 1 while (mantissa & 0x0400) == 0: mantissa <<= 1 exponent -= 0x0400 mantissa &= 0x03FF else: exponent = 0 ur = (sign | exponent) | mantissa else: ur = (sign | exponent) | mantissa half_bytes = struct.pack(' type|None: import numpy as np if dtype in { dtypes.bfloat16, *dtypes.fp8s }: return np.float32 return np.dtype(dtype.fmt).type if dtype.fmt is not None else None def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821 import numpy as np return dtypes.fields()[np.dtype(npdtype).name] @functools.cache def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821 import numpy as np, torch if dtype == dtypes.uint64: return torch.uint64 if dtype == dtypes.bfloat16: return torch.bfloat16 if dtype in dtypes.fp8s: return torch.uint8 # NOTE: torch doesn't expose this mapping with a stable API try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype except TypeError: return None @functools.cache def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821 return {v:k for k in DTYPES_DICT.values() if (v:=_to_torch_dtype(k)) is not None}[torchdtype]