mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
385 lines
20 KiB
Python
385 lines
20 KiB
Python
from __future__ import annotations
|
|
from typing import Final, ClassVar, Callable, Literal
|
|
import math, struct, ctypes, functools
|
|
from dataclasses import dataclass, fields
|
|
from tinygrad.helpers import ceildiv, getenv, prod, round_up, next_power2, OSX
|
|
from enum import Enum, auto
|
|
|
|
class ConstFloat(float):
  """Float subclass that distinguishes -0.0 from 0.0 and where nan == nan.

  Two ConstFloats are equal when their IEEE-754 bit patterns match (so -0.0 != 0.0), and any two NaNs are equal.
  This keeps __eq__ consistent with __hash__, which hashes the bit pattern. Comparison with a plain float/int
  falls back to ordinary numeric equality.
  """
  __slots__ = ('bits',)
  bits: int  # raw 64-bit pattern of the value, read as a little-endian double reinterpreted as uint64
  def __new__(cls, v:float):
    obj = super().__new__(cls, v)
    obj.bits = struct.unpack('<Q', struct.pack('<d', v))[0]
    return obj
  def __eq__(self, other):
    if self is other: return True
    if isinstance(other, float) and math.isnan(self) and math.isnan(other): return True
    # between ConstFloats compare bit patterns so equality agrees with __hash__ (e.g. 0.0 vs -0.0)
    if isinstance(other, ConstFloat): return self.bits == other.bits
    return float.__eq__(self, other)
  def __ne__(self, other):
    # float defines its own __ne__, which would otherwise be inherited and disagree with __eq__ (e.g. nan != nan)
    eq = self.__eq__(other)
    return eq if eq is NotImplemented else not eq
  def __hash__(self): return hash(self.bits)
|
|
|
|
class InvalidType:
  """Singleton marker for an invalid constant.

  Constructing it always yields the same instance, which is equal only to itself, hashes by identity,
  orders both below and above any other object, prints as "Invalid" and unpickles back to the singleton.
  """
  _instance: ClassVar[InvalidType|None] = None
  def __new__(cls):
    singleton = cls._instance
    if singleton is None:
      singleton = cls._instance = object.__new__(cls)
    return singleton
  def __eq__(self, other): return other is self
  def __hash__(self): return id(self)
  def __lt__(self, other): return other is not self
  def __gt__(self, other): return other is not self
  def __repr__(self): return "Invalid"
  def __format__(self, spec): return "Invalid"
  def __reduce__(self): return (InvalidType, ()) # unpickle returns the singleton
|
|
|
|
# module-level handle to the InvalidType singleton
Invalid = InvalidType()

# Python scalar types a constant value can be
PyConst = float|int|bool
# a constant: a Python scalar or the Invalid marker
ConstType = PyConst|InvalidType

# struct-module format characters a DType.fmt may hold (DType.fmt is None when there is no native format)
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
|
|
|
|
# all DTypes should only be created once
class DTypeMetaClass(type):
  """Metaclass that interns DTypes: calling a DType class with the same field values returns the same object.

  DType and its subclasses are dataclasses with eq=False, so == is identity and this cache is what makes equal
  dtypes compare equal. Before the cache lookup, arguments are normalized to one positional value per __init__
  field, with omitted defaulted fields filled in. Keyword calls (e.g. from dataclasses.replace) and calls that
  omit defaults therefore return the same object as the equivalent fully-positional call, instead of silently
  dropping the keywords.
  """
  dcache: dict[tuple, DType] = {}
  def __call__(cls, *args, **kwargs):
    # fast path: the usual fully-positional call is already canonical
    if kwargs or len(args) != len(DTypeMetaClass._init_fields(cls)): args = DTypeMetaClass._canonical_args(cls, args, kwargs)
    if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
    DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
    return ret
  @staticmethod
  @functools.cache
  def _init_fields(cls) -> tuple:
    """The dataclass fields of `cls` that are __init__ parameters, in positional order."""
    return tuple(f for f in fields(cls) if f.init)
  @staticmethod
  def _canonical_args(cls, args:tuple, kwargs:dict) -> tuple:
    """Merge `args` and `kwargs` into one positional tuple covering every __init__ field of `cls`, filling defaults.

    Raises TypeError for too many positional arguments, a missing required field, or unknown/duplicate keywords.
    """
    from dataclasses import MISSING
    init_fields = DTypeMetaClass._init_fields(cls)
    if len(args) > len(init_fields): raise TypeError(f"{cls.__name__}() takes {len(init_fields)} arguments but {len(args)} were given")
    out = list(args)
    for f in init_fields[len(args):]:
      if f.name in kwargs: out.append(kwargs.pop(f.name))
      elif f.default is not MISSING: out.append(f.default)
      elif f.default_factory is not MISSING: out.append(f.default_factory())
      else: raise TypeError(f"{cls.__name__}() missing required argument: '{f.name}'")
    # anything left either names no field or repeats a field already given positionally
    if kwargs: raise TypeError(f"{cls.__name__}() got unexpected or duplicate keyword arguments: {sorted(kwargs)}")
    return tuple(out)
|
|
|
|
class AddrSpace(Enum):
  """Address-space tag carried by pointer dtypes; repr() is the same as str(), e.g. AddrSpace.GLOBAL."""
  GLOBAL = auto()
  LOCAL = auto()
  REG = auto()
  def __repr__(self): return str(self)
|
|
|
|
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):
  """Immutable description of a scalar or vector element type.

  Construction goes through DTypeMetaClass, which returns a cached instance for a repeated argument tuple.
  With eq=False, == and hash() are the default identity-based ones, so that caching is what makes two equal
  dtypes compare equal.
  """
  priority: int # this determines when things get upcasted
  bitsize: int  # size in bits; vec() multiplies the scalar bitsize by the lane count
  name: str  # type name such as "float" or "unsigned char"; vec() names vectors from INVERSE_DTYPES_DICT plus the count
  fmt: FmtStr|None  # struct format character, None when there is none (e.g. bfloat16, fp8, every vector dtype)
  count: int  # number of vector lanes, 1 for scalars
  _scalar: DType|None  # the scalar dtype of a vector, None for scalars (see scalar())
  @property
  def itemsize(self) -> int: return (self.bitsize + 7) // 8  # bytes, rounded up (bool's 1 bit -> 1 byte)
  @staticmethod
  def new(priority:int, bitsize:int, name:str, fmt:FmtStr|None): return DType(priority, bitsize, name, fmt, 1, None)  # scalar constructor
  # pickle by re-calling the constructor with every field, so unpickling goes through the DTypeMetaClass cache
  def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count != 1 else "")
  # total order used by min()/max() over dtypes (least_upper_dtype, dtypes.from_py); priority is compared first
  # NOTE(review): if priority, bitsize and name all tie, comparing a None fmt with a str fmt raises TypeError
  def __lt__(self, o:DType): return (self.priority, self.bitsize, self.name, self.fmt, self.count) < (o.priority, o.bitsize, o.name, o.fmt, o.count)
  @property
  def base(self): return self  # PtrDType overrides this to return the pointee
  @property
  def vcount(self): return self.count  # lane count; PtrDType overrides this with its pointer-vector width
  @functools.cache # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
    """Return the `sz`-lane vector of this scalar dtype; sz == 1 and void return self. Vectors have no struct fmt."""
    assert self.count == 1, f"can't vectorize {self} with size {sz}"
    if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
    return DType(self.priority, self.bitsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
    """Return a pointer to this dtype holding `size` elements (-1 = unlimited) in `addrspace`."""
    return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
  def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
  def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes")
  @functools.cached_property
  def min(self):
    # ints: 0 if unsigned, else the two's-complement minimum; floats: -inf; anything else (e.g. bool): False
    if dtypes.is_int(self): return 0 if dtypes.is_unsigned(self) else -2**(self.scalar().bitsize-1)
    return -float("inf") if dtypes.is_float(self) else False
  @functools.cached_property
  def max(self):
    # ints: 2**bits - 1 shifted by min (so 2**(bits-1) - 1 when signed); floats: inf; anything else (e.g. bool): True
    if dtypes.is_int(self): return 2**(self.scalar().bitsize)-1+self.min
    return float("inf") if dtypes.is_float(self) else True
  def const(self, val: tuple[ConstType, ...]|ConstType):
    """Normalize a Python constant, or a tuple with one constant per lane, to this dtype's Python representation.

    Tuples are converted element-wise (length must equal count); Invalid passes through unchanged.
    """
    if isinstance(val, tuple):
      assert len(val) == self.count, f"mismatch {val} {self}"
      return tuple(map(self.const, val))
    if isinstance(val, InvalidType): return val
    # NOTE: float('nan') != float('nan'), so we canonicalize here
    if isinstance(val, float) and math.isnan(val): val = math.nan
    # int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
    return ConstFloat(float(val)) if dtypes.is_float(self) else bool(val) if dtypes.is_bool(self) else int(val)
|
|
|
|
@dataclass(frozen=True, eq=False)
class PtrDType(DType):
  """Pointer dtype. The inherited DType fields copy the pointee's; `_base` is the pointee dtype itself.

  `v` is the pointer-vector width reported by vcount (1 for a plain pointer) and `size` is the number of pointee
  elements, -1 meaning unlimited.
  """
  _base: DType
  addrspace: AddrSpace
  v: int
  size: int = -1 # -1 is unlimited size
  @property
  def base(self): return self._base
  @functools.cache # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
    """Return the `sz`-wide vectorized form of this pointer (self for sz == 1)."""
    assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
    if sz == 1: return self # sz=1 is a scalar
    if isinstance(self, ImageDType):
      # the generic constructor below doesn't know the image-only fields; pass all of them, including an explicit
      # pitch, so the vectorized image keeps the same row layout as the original
      return ImageDType(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size,
                        self.shape, self._pitch)
    return type(self)(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: raise RuntimeError("can't make a pointer from a pointer")
  def nbytes(self) -> int:
    """Size of the pointed-to buffer in bytes; raises RuntimeError when size is unlimited (-1)."""
    if self.size == -1: raise RuntimeError("can't get nbytes of a pointer with unlimited size")
    return self.size*self.itemsize
  @property
  def vcount(self): return self.v
  def __repr__(self):
    return f"{self.base.__repr__()}.ptr({self.size}{', '+str(self.addrspace) if self.addrspace != AddrSpace.GLOBAL else ''})" + \
      (f'.vec({self.v})' if self.v != 1 else '')
|
|
|
|
@dataclass(frozen=True, eq=False)
class ImageDType(PtrDType):
  """Pointer dtype for an image buffer; `pitch` treats shape[0] as height and shape[1] as width."""
  shape: tuple[int, ...] = () # shape of the Image
  _pitch: int = -1  # explicit row pitch in bytes; -1 means the `pitch` property computes it
  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
    # an image dtype is already a pointer: `size` is ignored and self is returned
    assert addrspace == AddrSpace.GLOBAL, "images can't be local"
    return self
  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
  @property
  def pitch(self):
    """Row pitch in bytes: the explicit _pitch if one was given, otherwise derived from width, height and itemsize."""
    if self._pitch != -1: return self._pitch
    imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize))
    # OSX: width rounded up to a multiple of 256 pixels; the factor 4 presumably counts 4 channels per pixel — confirm
    if OSX: return round_up(imgw, 256) * 4 * self.itemsize
    # needs to be IMAGE_PITCH_ALIGN=256 for AMD
    # log2 of the minimum row alignment in bytes: log2(IMAGE_PITCH_ALIGN) when set, else 6 (64 bytes)
    min_pitchalign = int(math.log2(v)) if (v := getenv("IMAGE_PITCH_ALIGN", 0)) > 0 else 6
    # NOTE(review): the rules below look like an empirically matched copy of a GPU driver's pitch computation;
    # the constants are not derived here — confirm against the target driver before changing them
    pitchalign = max(min_pitchalign, 11 - int(math.log2(imgh))) if imgh > 1 else min_pitchalign
    align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))

    granularity = 128 if self.itemsize == 4 else 256
    # an extra (1 << pitchalign) bytes are added when imgw lands within align_up of the next padding boundary
    pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
    return round_up(imgw * 4 * self.itemsize, 1 << pitchalign) + pitch_add

  # get list of (height, width) that do not require pitch padding
  @staticmethod
  def valid_dims(ptr:PtrDType) -> list[tuple[int,int]]:
    # pxls: the buffer's element count divided by 4 (one pixel per 4 elements); ALIGN: row-width granularity in pixels
    ALIGN, MAXW, pxls = getenv("IMAGE_PITCH_ALIGN", 256 if OSX else 64), 16384, ptr.size // 4
    # only half/float buffers that fit in a MAXW x MAXW image are candidates
    if ptr.base not in (dtypes.half, dtypes.float) or ptr.size > 4*MAXW*MAXW: return []
    # OSX has stricter requirements for height=1 images
    # when the size isn't a whole number of aligned rows, a single row is the only option, and only off OSX when
    # the byte size is a multiple of IMAGE_BASE_ALIGN
    if ptr.size % (ALIGN * 4) != 0: return [] if OSX or ptr.nbytes() % getenv("IMAGE_BASE_ALIGN", 64) != 0 else [(1, pxls)]
    # widths are multiples of ALIGN (at most MAXW) that evenly divide the row count; k's lower bound keeps height <= MAXW
    return [(pxls//ALIGN//k, ALIGN*k) for k in range(ceildiv(pxls//ALIGN, MAXW), min(pxls//ALIGN, MAXW//ALIGN)+1) if (pxls//ALIGN)%k == 0]
|
|
|
|
class dtypes:
  """Namespace holding the DType singletons, their aliases, dtype groupings, and classification helpers."""
  # the classification helpers look at x.scalar(), so vector dtypes classify like their scalar
  @staticmethod
  @functools.cache
  def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)  # images count as float
  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
  @functools.cache
  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.ints + (dtypes.weakint,))  # weakint counts as int
  @staticmethod
  @functools.cache
  def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
  @staticmethod
  def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
  @staticmethod
  def from_py(x) -> DType:
    """Infer the DType of a Python scalar, or of a (nested) list/tuple as the max() of its elements' dtypes.

    Uses exact class checks, so subclasses (e.g. ConstFloat) raise RuntimeError; an empty list/tuple gives default_float.
    """
    if x.__class__ is float: return dtypes.default_float
    if x.__class__ is int: return dtypes.default_int
    if x.__class__ is bool: return dtypes.bool
    # put this in the last is faster because there are more items than lists/tuples to check
    if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
    raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
  @staticmethod
  def finfo(dtype:DType) -> tuple[int, int]:
    """(exponent, mantissa) bit counts of a scalar floating point dtype.

    Raises ValueError for non-float dtypes. Vector and image dtypes pass the is_float check but are not keys of the
    table below, so they raise KeyError.
    """
    if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
    return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52),
            dtypes.fp8e4m3: (4, 3), dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3fnuz: (4, 3), dtypes.fp8e5m2fnuz: (5, 2)}[dtype]
  # DType.new(priority, bitsize, name, fmt); least_upper_dtype returns the min() (lowest priority) common parent
  void: Final[DType] = DType.new(-1, 0, "void", None)  # bitsize 0; vec() returns void unchanged
  weakint: Final[DType] = DType.new(-1, 800, "weakint", None)  # NOTE(review): 800-bit placeholder, presumably an untyped ("weak") int — confirm
  bool: Final[DType] = DType.new(0, 1, "bool", '?')
  int8: Final[DType] = DType.new(1, 8, "signed char", 'b')
  uint8: Final[DType] = DType.new(2, 8, "unsigned char", 'B')
  int16: Final[DType] = DType.new(3, 16, "short", 'h')
  uint16: Final[DType] = DType.new(4, 16, "unsigned short", 'H')
  int32: Final[DType] = DType.new(5, 32, "int", 'i')
  uint32: Final[DType] = DType.new(6, 32, "unsigned int", 'I')
  int64: Final[DType] = DType.new(7, 64, "long", 'q')
  uint64: Final[DType] = DType.new(8, 64, "unsigned long", 'Q')
  # underscore-prefixed names are excluded from DTYPES_DICT
  _uint128: Final[DType] = DType.new(8, 128, "uint128", None)
  _uint256: Final[DType] = DType.new(8, 256, "uint256", None)
  # the fnuz variants share the priority of their OCP counterparts
  fp8e4m3: Final[DType] = DType.new(9, 8, "float8_e4m3", None)
  fp8e5m2: Final[DType] = DType.new(10, 8, "float8_e5m2", None)
  fp8e4m3fnuz: Final[DType] = DType.new(9, 8, "float8_e4m3fnuz", None)
  fp8e5m2fnuz: Final[DType] = DType.new(10, 8, "float8_e5m2fnuz", None)
  float16: Final[DType] = DType.new(11, 16, "half", 'e')
  # bfloat16 has higher priority than float16, so when both are common parents least_upper_dtype picks float16,
  # e.g. least_upper_dtype(dtypes.fp8e4m3, dtypes.fp8e5m2) = dtypes.float16
  bfloat16: Final[DType] = DType.new(12, 16, "__bf16", None)
  float32: Final[DType] = DType.new(13, 32, "float", 'f')
  float64: Final[DType] = DType.new(14, 64, "double", 'd')

  # dtype aliases
  # defined after the canonical names, so INVERSE_DTYPES_DICT (and therefore repr) uses these alias names where they exist
  half = float16; float = float32; double = float64 # noqa: E702
  uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
  char = int8; short = int16; int = int32; long = int64 # noqa: E702

  # NOTE: these are image dtypes
  # both use float32 as the pointee (base) dtype; the element count is prod(shp)
  @staticmethod
  def imageh(shp, pitch=-1): return ImageDType(100, 16, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
  @staticmethod
  def imagef(shp, pitch=-1): return ImageDType(100, 32, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)

  # default_float can be replaced at import time via the DEFAULT_FLOAT environment variable
  default_float: ClassVar[DType] = float32
  default_int: ClassVar[DType] = int32

  # dtype groupings
  fp8_ocp = (fp8e4m3, fp8e5m2)
  fp8_fnuz = (fp8e4m3fnuz, fp8e5m2fnuz)
  fp8s = fp8_ocp + fp8_fnuz
  floats = fp8s + (float16, bfloat16, float32, float64)
  int8s = (uint8, int8)
  int16s = (uint16, int16)
  int32s = (uint32, int32)
  int64s = (uint64, int64)
  uints = (uint8, uint16, uint32, uint64)
  sints = (int8, int16, int32, int64)
  ints = uints + sints
  all = floats + ints + (bool, weakint) # noqa: A003
|
|
|
|
# DEFAULT_FLOAT=<dtype name> replaces dtypes.default_float; the name is lowercased before the lookup on dtypes
env_default_float = getenv("DEFAULT_FLOAT", "")
if env_default_float:
  dtypes.default_float = getattr(dtypes, env_default_float.lower())
  assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
|
|
|
|
# anything to_dtype accepts: a DType, or the (case-insensitive) name of a dtypes attribute
DTypeLike = str|DType
def to_dtype(dtype:DTypeLike) -> DType:
  """Return `dtype` unchanged if it is a DType, otherwise look its lowercased name up on `dtypes`."""
  if not isinstance(dtype, DType): return getattr(dtypes, dtype.lower())
  return dtype
|
|
|
|
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support complex type
# TODO: weakint and weakfloat in lattice
# promo_lattice maps each dtype to its immediate parents; float64 is the top and has no entry.
# _get_recursive_parents / least_upper_dtype walk these edges upward. int64's only parent is uint64, so
# least_upper_dtype(int64, uint64) is uint64; uint64 in turn promotes to the fp8 types.
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
  dtypes.int64: [dtypes.uint64], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
  dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2, dtypes.fp8e4m3fnuz, dtypes.fp8e5m2fnuz],
  dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16],
  dtypes.fp8e4m3fnuz: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e5m2fnuz: [dtypes.float16, dtypes.bfloat16],
  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
|
|
|
|
@functools.cache
def _get_recursive_parents(dtype:DType) -> set[DType]:
  """`dtype` together with every dtype reachable from it through promo_lattice (float64 is the top)."""
  if dtype == dtypes.float64: return {dtypes.float64}
  reachable = {dtype}
  # union into a fresh set so the cached results of the parents are never mutated
  for parent in promo_lattice[dtype]: reachable |= _get_recursive_parents(parent)
  return reachable
|
|
@functools.cache
def least_upper_dtype(*ds:DType) -> DType:
  """Lowest dtype (by DType ordering) that every argument's scalar promotes to; the first image argument wins outright."""
  images = [d for d in ds if isinstance(d, ImageDType)]
  if images: return images[0]
  common_parents = set.intersection(*(_get_recursive_parents(d.scalar()) for d in ds))
  return min(common_parents)
|
|
def least_upper_float(dt:DType) -> DType:
  """Return `dt` if it is already a float dtype, else the least upper dtype of `dt` and dtypes.default_float."""
  if dtypes.is_float(dt): return dt
  return least_upper_dtype(dt, dtypes.default_float)
|
|
|
|
# attribute name -> DType for every DType attribute of `dtypes`, aliases included (e.g. both "float32" and "float"),
# excluding the default_* settings, void, weakint and underscore-prefixed internals; order follows the class body
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "weakint", "_"))}
# DType.name (e.g. "unsigned char") -> attribute name used by DType.__repr__. When several attributes share a DType,
# the last one in DTYPES_DICT order wins, i.e. an alias defined after the canonical name.
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "weakint":"weakint"}
|
|
|
|
@functools.cache
def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
  """Return True if casting dt0 -> dt1 preserves every value of dt0.

  Similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html. Identity casts and casts from
  bool always succeed; otherwise dt0 must be among the sources listed for dt1 below, and any unlisted dt1 gives False.
  """
  if dt0 == dt1 or dt0 == dtypes.bool: return True
  i8, i16, i32 = dtypes.int8, dtypes.int16, dtypes.int32
  u8, u16, u32 = dtypes.uint8, dtypes.uint16, dtypes.uint32
  lossless_sources: dict[DType, tuple[DType, ...]] = {
    dtypes.weakint: dtypes.ints,
    dtypes.double: (dtypes.float, dtypes.half, dtypes.bfloat16, *dtypes.fp8s, u32, u16, u8, i32, i16, i8),
    dtypes.float: (dtypes.half, dtypes.bfloat16, *dtypes.fp8s, u16, u8, i16, i8),
    dtypes.half: (*dtypes.fp8s, u8, i8),
    dtypes.uint64: (u32, u16, u8),
    dtypes.uint32: (u16, u8),
    dtypes.uint16: (u8,),
    dtypes.int64: (u32, u16, u8, i32, i16, i8),
    dtypes.int32: (u16, u8, i16, i8),
    dtypes.int16: (u8, i8),
  }
  return dt0 in lossless_sources.get(dt1, ())
|
|
|
|
def sum_acc_dtype(dt:DType):
  """Default accumulator dtype for a sum over `dt`.

  Unsigned ints accumulate in at least uint32, signed ints and bool in at least int32, and everything else in at
  least the dtype named by the SUM_DTYPE env var (float32 by default); "at least" is resolved by least_upper_dtype.
  """
  if dtypes.is_unsigned(dt): acc_floor = dtypes.uint
  elif dtypes.is_int(dt) or dt == dtypes.bool: acc_floor = dtypes.int
  else: acc_floor = to_dtype(getenv("SUM_DTYPE", "float32"))
  return least_upper_dtype(dt, acc_floor)
|
|
|
|
def float_to_fp16(x):
  """Round `x` to the nearest IEEE-754 half-precision value and return it as a Python float.

  Magnitudes too large for float16, including ints too large even for a Python float, saturate to +/-inf;
  nan passes through.
  """
  try: return struct.unpack('e', struct.pack('e', float(x)))[0]
  # avoid math.copysign(math.inf, x) here: for a huge int x it converts x to float again and overflows a second time
  except OverflowError: return math.inf if x > 0 else -math.inf
|
|
|
|
def float_to_bf16(x):
  """Round `x` to the nearest bfloat16 value (ties to even) and return it as a Python float.

  Non-finite inputs are returned unchanged. Values beyond the float32 range (which bfloat16 shares), including ints
  too large for a Python float, saturate to +/-inf instead of raising OverflowError.
  """
  try:
    if not math.isfinite(x): return x
    u = struct.unpack('I', struct.pack('f', x))[0]
  except OverflowError: return math.inf if x > 0 else -math.inf
  # round-to-nearest-even on the low 16 bits of the float32 pattern, then drop them
  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
  return struct.unpack('f', struct.pack('I', u))[0]
|
|
|
|
# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
# (bias, sig_bits, mant_mask, min_denorm_half, ovf_threshold, max_norm, min_norm)
# bias: exponent bias; sig_bits: stored mantissa bits + 1; mant_mask: mask of the stored mantissa bits.
# min_denorm_half, ovf_threshold and min_norm are float64 bit patterns that float_to_fp8 compares with |x|:
# |x| <= min_denorm_half rounds to zero, |x| > ovf_threshold saturates to max_norm (the 8-bit code of the largest
# finite value), and |x| >= min_norm is encoded as a normal number.
_fp8_cfg = {
  dtypes.fp8e4m3: (7, 4, 0x7, 0x3F50000000000000, 0x407D000000000000, 0x7E, 0x3F90000000000000),
  dtypes.fp8e5m2: (15, 3, 0x3, 0x3EE0000000000000, 0x40EE000000000000-1, 0x7B, 0x3F10000000000000),
  dtypes.fp8e4m3fnuz: (8, 4, 0x7, 0x3F40000000000000, 0x406F000000000000-1, 0x7F, 0x3F80000000000000),
  dtypes.fp8e5m2fnuz: (16, 3, 0x3, 0x3ED0000000000000, 0x40EE000000000000-1, 0x7F, 0x3F00000000000000),
}
|
|
|
|
def float_to_fp8(x: float, dtype: DType) -> int:
  """Encode `x` as the 8-bit code of the fp8 `dtype`, rounding to nearest with ties to even.

  Finite values too large for the format saturate to the largest finite code (max_norm). Non-finite inputs: the
  fnuz formats return 0x80 (their NaN code), e4m3 returns a signed NaN code, e5m2 a signed inf or NaN code.
  """
  assert dtype in dtypes.fp8s, "Only for fp8s"
  if dtype in dtypes.fp8_fnuz and not math.isfinite(x): return 0x80
  # fnuz has no negative zero, so both +0.0 and -0.0 encode as 0x00
  if dtype in dtypes.fp8_fnuz and x == 0.0: return 0x00
  # e4m3 don't support inf, return 0x7f(+NaN) and 0xff(-NaN) to match jax
  # NaN is unordered, can't compare with zero, use math.copysign to get sign
  if dtype == dtypes.fp8e4m3 and not math.isfinite(x): return 0x7f if math.copysign(1, x) > 0 else 0xff
  if dtype == dtypes.fp8e5m2 and not math.isfinite(x): return (0 if math.copysign(1, x) > 0 else 0x80) | (0x7c if math.isinf(x) else 0x7f)
  bias, sig_bits, mant_mask, min_denorm_half, ovf_threshold, max_norm, min_norm = _fp8_cfg[dtype]
  # work on the raw float64 bit pattern of x
  xbits, = struct.unpack('Q', struct.pack('d', x))
  # half an fp8 ulp, in units of the lowest float64 mantissa bit (normal range)
  half_ulp = 1 << (52 - sig_bits)
  # sign moved to bit 7, exponent rebiased for fp8, top (sig_bits-1) float64 mantissa bits, and |x| as raw bits
  sign, exp, mantissa, absx = ((xbits>>63)&1)<<7, ((xbits>>52)&0x7FF)-1023+bias, (xbits>>(53-sig_bits))&mant_mask, xbits&0x7FFFFFFFFFFFFFFF
  # for non-negative doubles, comparing bit patterns compares magnitudes
  if absx <= min_denorm_half: res = 0
  elif absx > ovf_threshold: res = max_norm
  elif absx >= min_norm:
    # normal: pack exponent and mantissa, then round on the dropped bits (a mantissa carry bumps the exponent)
    res, round_bits = (exp << (sig_bits - 1)) | mantissa, xbits & ((half_ulp << 1) - 1)
    if round_bits > half_ulp or (round_bits == half_ulp and mantissa & 1): res += 1
  else:
    # subnormal: restore the implicit leading 1, shift right by how far exp is below 1, and round on the dropped bits
    shift = 1 - exp
    mantissa |= 1 << (sig_bits - 1)
    res, half = mantissa >> shift, half_ulp << shift
    round_bits = (xbits | (1 << 52)) & ((half << 1) - 1)
    if round_bits > half or (round_bits == half and res & 1): res += 1
  return 0 if dtype in dtypes.fp8_fnuz and res == 0 else int(res | sign) # fnuz has no negative zero
|
|
|
|
def fp8_to_float(x: int, dtype: DType) -> float:
  """Decode the 8-bit code `x` of the fp8 `dtype` into a Python float (inverse of float_to_fp8)."""
  assert dtype in dtypes.fp8s, "Only for fp8s"
  # fnuz: the would-be negative-zero code 0x80 is the format's NaN
  if dtype in dtypes.fp8_fnuz and x == 0x80: return math.nan
  if (x & 0x7F) == 0: return -0.0 if x & 0x80 else 0.0
  bias, sig_bits, *_ = _fp8_cfg[dtype]
  # 1 sign bit + exp_bits + mant_bits = 8
  mant_bits, exp_bits = sig_bits - 1, 8 - sig_bits
  exp_max, mant_max = (1 << exp_bits) - 1, (1 << mant_bits) - 1
  sign, exp, mantissa = (x >> 7) & 1, (x >> mant_bits) & exp_max, x & mant_max
  # OCP formats reserve the all-ones exponent: e5m2 for inf/nan, e4m3 only for nan (all-ones mantissa)
  if dtype not in dtypes.fp8_fnuz and exp == exp_max:
    if dtype == dtypes.fp8e5m2: return math.copysign(math.nan if mantissa else math.inf, -1 if sign else 1)
    if mantissa == mant_max: return math.nan
  # exp == 0 is subnormal (no implicit leading 1, exponent fixed at 1 - bias); otherwise a normal number
  val = (mantissa / (mant_max + 1)) * 2 ** (1 - bias) if exp == 0 else (1 + mantissa / (mant_max + 1)) * 2 ** (exp - bias)
  return -val if sign else val
|
|
|
|
def storage_fmt_for_dtype(dtype:DType):
  """struct format used to store one element of `dtype`: 'H' for bfloat16, 'B' for fp8, otherwise dtype.fmt."""
  if dtype == dtypes.bfloat16: return 'H'
  return 'B' if dtype in dtypes.fp8s else dtype.fmt
|
|
|
|
def to_storage_scalar(x, dtype:DType):
  """Convert a Python scalar to its stored representation for `dtype`.

  half is rounded to fp16, bfloat16 becomes its 16 raw bits, fp8 becomes its 8-bit code, and any other dtype passes
  `x` through (presumably to be packed with storage_fmt_for_dtype(dtype) — confirm at call sites).
  """
  if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype)
  if dtype == dtypes.bfloat16:
    as_f32_bytes = struct.pack('f', float_to_bf16(x))
    return (struct.unpack('I', as_f32_bytes)[0] >> 16) & 0xFFFF
  if dtype == dtypes.half: return float_to_fp16(x)
  return x
|
|
|
|
def from_storage_scalar(x, dtype:DType):
  """Decode a stored scalar: bfloat16 raw bits and fp8 codes become Python floats, anything else passes through."""
  if dtype in dtypes.fp8s: return fp8_to_float(int(x), dtype)
  if dtype == dtypes.bfloat16:
    (decoded,) = struct.unpack('f', struct.pack('I', (x & 0xFFFF) << 16))
    return decoded
  return x
|
|
|
|
# truncate[dt](x) maps a Python scalar to the value it has after being stored as dt
truncate: dict[DType, Callable] = {dtypes.bool: bool,
  dtypes.float16: float_to_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
  # fp8: encode then decode; the dtype=fp8 default binds each loop value (avoids late binding in the lambdas)
  **{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
  # float/double and the fixed-width ints go through the matching ctypes type (ctypes ints do no overflow checking, so they wrap)
  **{getattr(dtypes, n): (lambda x, c=getattr(ctypes, f'c_{n}'): c(x).value)
     for n in ('float', 'double', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64')}}
|
|
|
|
# numpy and torch dtype interop
|
|
|
|
def _to_np_dtype(dtype:DType) -> type|None:
  """numpy scalar type for `dtype`: np.float32 for bfloat16 and fp8, None when dtype has no struct fmt."""
  import numpy as np
  if dtype == dtypes.bfloat16 or dtype in dtypes.fp8s: return np.float32
  return None if dtype.fmt is None else np.dtype(dtype.fmt).type
|
|
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  """DType registered in DTYPES_DICT under numpy's canonical name for `npdtype` (KeyError if there is none)."""
  import numpy as np
  canonical_name = np.dtype(npdtype).name
  return DTYPES_DICT[canonical_name]
|
|
|
|
@functools.cache
def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821
  """torch dtype used to hold data of `dtype`, or None when there is no equivalent.

  uint64 and bfloat16 map directly and fp8 types are held as raw bytes (torch.uint8). Everything else goes through
  the numpy mapping. dtypes with no numpy equivalent (e.g. void or vector dtypes, whose fmt is None) return None
  rather than the float64 that numpy would pick for dtype=None.
  """
  import numpy as np, torch
  if dtype == dtypes.uint64: return torch.uint64
  if dtype == dtypes.bfloat16: return torch.bfloat16
  if dtype in dtypes.fp8s: return torch.uint8
  if (np_dtype := _to_np_dtype(dtype)) is None: return None
  # NOTE: torch doesn't expose this mapping with a stable API
  try: return torch.from_numpy(np.array([], dtype=np_dtype)).dtype
  except TypeError: return None
|
|
@functools.cache
def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  """Inverse of _to_torch_dtype over the dtypes in DTYPES_DICT; raises KeyError for an unmapped torch dtype.

  fp8 types are skipped explicitly: they are only held as torch.uint8, so torch.uint8 must map back to dtypes.uint8
  regardless of where DTYPES_DICT happens to list the fp8 entries relative to uint8 and its aliases.
  """
  return {v:k for k in DTYPES_DICT.values() if k not in dtypes.fp8s and (v:=_to_torch_dtype(k)) is not None}[torchdtype]
|