mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
358 lines
18 KiB
Python
358 lines
18 KiB
Python
from __future__ import annotations
|
|
from typing import Final, ClassVar, Callable, Literal
|
|
import math, struct, ctypes, functools
|
|
from dataclasses import dataclass, fields
|
|
from tinygrad.helpers import getenv, prod, round_up, next_power2
|
|
from enum import Enum, auto
|
|
|
|
class InvalidTypeMetaClass(type):
  """Metaclass that turns its classes into singletons: every instantiation returns one shared object."""
  instance:None|InvalidType = None
  def __call__(cls):
    existing = InvalidTypeMetaClass.instance
    if existing is None:
      # first call: build the object once and remember it on the metaclass
      existing = InvalidTypeMetaClass.instance = super().__call__()
    return existing
|
|
|
|
class InvalidType(metaclass=InvalidTypeMetaClass):
  """Singleton sentinel printed as "Invalid"; it is equal only to itself and its hash is its identity."""
  def __repr__(self): return "Invalid"
  def __hash__(self): return id(self)
  def __eq__(self, other): return other is self
  # ordering: Invalid reports both < and > against anything that is not itself
  def __lt__(self, other): return not (other is self)
  def __gt__(self, other): return not (other is self)
  # unpickling calls InvalidType(), which the metaclass resolves to the shared instance
  def __reduce__(self): return (InvalidType, ())
|
|
|
|
# the shared Invalid sentinel (InvalidType() always yields the same object)
Invalid = InvalidType()

# python scalar types a constant value may take
ConstType = float | int | bool

# struct-module format characters a DType.fmt may hold
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
|
|
|
|
# all DTypes should only be created once
class DTypeMetaClass(type):
  """Metaclass that interns instances: calling a class with the same arguments returns the same cached object.

  DType is a dataclass declared with eq=False, so dtype equality is object identity and this interning is what makes
  two dtypes built from identical arguments compare equal.
  """
  dcache: dict[tuple, DType] = {}
  def __call__(cls, *args, **kwargs):
    # the concrete class is part of the key: a PtrDType and an ImageDType (whose trailing `shape` has a default)
    # can be built from identical positional args. keyword args are forwarded and included in the key.
    key = (cls, args, tuple(sorted(kwargs.items()))) if kwargs else (cls, args)
    if (ret:=DTypeMetaClass.dcache.get(key, None)) is not None: return ret
    DTypeMetaClass.dcache[key] = ret = super().__call__(*args, **kwargs)
    return ret
|
|
|
|
class AddrSpace(Enum):
  """Address space a pointer dtype refers to."""
  GLOBAL = auto()
  LOCAL = auto()
  REG = auto()

  def __repr__(self):
    # print as AddrSpace.NAME rather than Enum's default <AddrSpace.NAME: value>
    return str(self)
|
|
|
|
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):
  """A scalar or vector element type.

  Instances are interned by DTypeMetaClass and the dataclass uses eq=False, so DTypes compare equal only when they are
  the same object. A vector dtype records its element type in `_scalar`.
  """
  priority: int # this determines when things get upcasted
  itemsize: int # bytes per element; for a vector this is the scalar itemsize times count
  name: str
  fmt: FmtStr|None # struct format character, None when there is none (e.g. vectors)
  count: int # number of vector lanes, 1 for scalars
  _scalar: DType|None # element dtype of a vector, None for scalars

  @staticmethod
  def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None):
    # a scalar has a single lane and no separate element dtype
    return DType(priority, itemsize, name, fmt, 1, None)

  def __reduce__(self):
    # pickle by re-calling the constructor with every field so unpickling goes through the interning metaclass
    field_values = tuple(getattr(self, f.name) for f in fields(self))
    return type(self), field_values

  def __repr__(self):
    text = f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"
    if self.count != 1: text += f".vec({self.count})"
    return text

  def __lt__(self, o:DType):
    lhs = (self.priority, self.itemsize, self.name, self.fmt, self.count)
    rhs = (o.priority, o.itemsize, o.name, o.fmt, o.count)
    return lhs < rhs

  @property
  def base(self): return self

  @property
  def vcount(self): return self.count

  @functools.cache  # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
    assert self.count == 1, f"can't vectorize {self} with size {sz}"
    # void doesn't vectorize, and sz=1 is scalar
    if sz == 1 or self == dtypes.void: return self
    vec_name = f"{INVERSE_DTYPES_DICT[self.name]}{sz}"
    return DType(self.priority, self.itemsize*sz, vec_name, None, sz, self)

  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
    # the pointer copies this dtype's fields and refers back to it through _base
    return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)

  def scalar(self) -> DType:
    return self if self._scalar is None else self._scalar

  def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes")

  @property
  def min(self): return dtypes.min(self)

  @property
  def max(self): return dtypes.max(self)
|
|
|
|
@dataclass(frozen=True, eq=False)
class PtrDType(DType):
  """Pointer dtype referring to elements of `_base` in address space `addrspace`.

  `v` is the pointer vector width set by vec() (1 for a single pointer); `size` is the element count, -1 meaning unlimited.
  """
  _base: DType
  addrspace: AddrSpace
  v: int
  size: int = -1 # -1 is unlimited size

  @property
  def base(self): return self._base

  @property
  def vcount(self): return self.v

  @functools.cache  # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
    assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
    if sz == 1: return self # sz=1 is a scalar
    # same fields as this pointer, with this pointer as the element (_scalar) and sz as the vector width
    shared = (self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
    if isinstance(self, ImageDType): return ImageDType(*shared, self.shape)
    return type(self)(*shared)

  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: raise RuntimeError("can't make a pointer from a pointer")

  def nbytes(self) -> int:
    if self.size != -1: return self.size*self.itemsize
    raise RuntimeError("can't get nbytes of a pointer with unlimited size")

  def __repr__(self):
    space = "" if self.addrspace == AddrSpace.GLOBAL else ", " + str(self.addrspace)
    text = f"{self.base!r}.ptr({self.size}{space})"
    if self.v != 1: text += f".vec({self.v})"
    return text
|
|
|
|
@dataclass(frozen=True, eq=False)
class ImageDType(PtrDType):
  """Pointer dtype for an image buffer; `shape` holds the image dimensions.

  pitch reads shape[0] as the image height and shape[1] as the image width.
  """
  shape: tuple[int, ...] = () # shape of the Image
  # images are global only; ptr() ignores `size` and returns this image dtype itself
  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
    assert addrspace == AddrSpace.GLOBAL, "images can't be local"
    return self
  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
  @property
  def pitch(self):
    """Row pitch of the image: the row size rounded up to a multiple of 2**pitchalign, plus an optional extra 2**pitchalign.

    NOTE(review): the alignment constants look like a driver/hardware-specific layout rule whose source is not shown
    here -- confirm against the target before changing. itemsize is expected to be 2 or 4 (dtypes.imageh / imagef);
    with itemsize 1, itemsize_log is 0 and the pitchalign == 6 branch below divides by zero.
    """
    imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize))
    # taller images get a smaller alignment exponent, but never below 6 (1 << 6 == 64)
    pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
    align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))

    granularity = 128 if self.itemsize == 4 else 256
    # one extra alignment unit when imgw is wider than granularity//2 and lies less than align_up below
    # min(next power of two, next multiple of granularity)
    pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
    # NOTE(review): imgw * 4 * itemsize presumably is the byte width of a row of 4-component texels -- confirm
    return round_up(imgw * 4 * self.itemsize, 1 << pitchalign) + pitch_add
|
|
|
|
class dtypes:
  """Namespace holding every DType singleton, their aliases and groupings, and helpers that classify dtypes."""
  @staticmethod
  @functools.cache
  def is_float(x: DType) -> bool:
    # vectors are classified by their scalar type; image dtypes always count as float
    return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
  @functools.cache
  def is_int(x: DType) -> bool:
    # signed and unsigned ints, plus dtypes.index
    return x.scalar() in dtypes.index_like
  @staticmethod
  @functools.cache
  def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
  @staticmethod
  def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
  @staticmethod
  def from_py(x) -> DType:
    """Infer a DType from a python float/int/bool or a (nested) list/tuple of them.

    Raises RuntimeError for any other python type.
    """
    if x.__class__ is float: return dtypes.default_float
    if x.__class__ is int: return dtypes.default_int
    if x.__class__ is bool: return dtypes.bool
    # put this in the last is faster because there are more items than lists/tuples to check
    # a sequence takes the largest element dtype (per DType.__lt__); an empty one defaults to default_float
    if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
    raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
  @staticmethod
  def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
    """Coerce val to the python type matching dtype: int for int dtypes, float for float dtypes, bool otherwise.

    A tuple val must have dtype.count elements and is converted element-wise; Invalid is returned unchanged.
    """
    if isinstance(val, tuple):
      assert len(val) == dtype.count, f"mismatch {val} {dtype}"
      return tuple(dtypes.as_const(x, dtype) for x in val)
    if isinstance(val, InvalidType): return val
    # NOTE: float('nan') != float('nan'), so we canonicalize here
    if isinstance(val, float) and math.isnan(val): val = math.nan
    return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
  @staticmethod
  @functools.cache
  def min(dtype:DType):
    """Lowest value of dtype: 0 or -2**(bits-1) for ints, -inf for floats, False otherwise."""
    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().itemsize*8-1)
    return -float("inf") if dtypes.is_float(dtype) else False
  @staticmethod
  @functools.cache
  def max(dtype:DType):
    """Highest value of dtype: 2**bits - 1 + min for ints, inf for floats, True otherwise."""
    if dtypes.is_int(dtype): return 2**(dtype.scalar().itemsize*8)-1+dtypes.min(dtype)
    return float("inf") if dtypes.is_float(dtype) else True
  @staticmethod
  def finfo(dtype:DType) -> tuple[int, int]:
    """(exponent, mantissa) bit counts of a float dtype; vector dtypes report their scalar type.

    Raises ValueError if dtype is not a float dtype.
    """
    if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
    # is_float accepts vectors (it checks the scalar type), so look up the scalar type here too
    return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52),
            dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3: (4, 3)}[dtype.scalar()]
  @staticmethod
  def fields() -> dict[str, DType]: return DTYPES_DICT
  void: Final[DType] = DType.new(-1, 0, "void", None)
  index: Final[DType] = DType.new(-1,100, "index", None)
  bool: Final[DType] = DType.new(0, 1, "bool", '?')
  int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
  uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
  int16: Final[DType] = DType.new(3, 2, "short", 'h')
  uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
  int32: Final[DType] = DType.new(5, 4, "int", 'i')
  uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
  int64: Final[DType] = DType.new(7, 8, "long", 'q')
  uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
  fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None)
  fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None)
  float16: Final[DType] = DType.new(11, 2, "half", 'e')
  # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.fp8e4m3, dtypes.fp8e5m2) = dtypes.float16
  bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None)
  float32: Final[DType] = DType.new(13, 4, "float", 'f')
  float64: Final[DType] = DType.new(14, 8, "double", 'd')

  # dtype aliases
  half = float16; float = float32; double = float64 # noqa: E702
  uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
  char = int8; short = int16; int = int32; long = int64 # noqa: E702

  # NOTE: these are image dtypes
  @staticmethod
  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
  @staticmethod
  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)

  default_float: ClassVar[DType] = float32
  default_int: ClassVar[DType] = int32

  fp8s = (fp8e4m3, fp8e5m2)
  floats = fp8s + (float16, bfloat16, float32, float64)
  uints = (uint8, uint16, uint32, uint64)
  sints = (int8, int16, int32, int64)
  ints = uints + sints
  index_like = ints + (index,)
  all = floats + ints + (bool, index) # noqa: A003
|
|
|
|
# DEFAULT_FLOAT=<dtype name> (e.g. DEFAULT_FLOAT=half) overrides dtypes.default_float at import time
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
  _env_default_float_dt = getattr(dtypes, env_default_float.lower(), None)
  # explicit raise instead of assert (stripped under `python -O`); also rejects names that are not dtypes at all
  if not (isinstance(_env_default_float_dt, DType) and dtypes.is_float(_env_default_float_dt)):
    raise ValueError(f"{env_default_float} is not a float dtype")
  dtypes.default_float = _env_default_float_dt
|
|
|
|
DTypeLike = str|DType
def to_dtype(dtype:DTypeLike) -> DType:
  """Return dtype unchanged if it is a DType, else the dtypes attribute named by the (case-insensitive) string.

  Raises AttributeError if the name is unknown or names something on `dtypes` that is not a DType.
  """
  if isinstance(dtype, DType): return dtype
  ret = getattr(dtypes, dtype.lower())
  # names like "fields", "floats" or "is_float" exist on dtypes but are not dtypes themselves
  if not isinstance(ret, DType): raise AttributeError(f"{dtype!r} is not a dtype")
  return ret
|
|
|
|
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
# each dtype maps to the dtypes it promotes to directly; float64 is the top and has no entry
promo_lattice = {
  dtypes.bool: [dtypes.int8, dtypes.uint8],
  dtypes.int8: [dtypes.int16],
  dtypes.int16: [dtypes.int32],
  dtypes.int32: [dtypes.int64],
  dtypes.int64: [dtypes.fp8e4m3, dtypes.fp8e5m2],
  dtypes.uint8: [dtypes.int16, dtypes.uint16],
  dtypes.uint16: [dtypes.int32, dtypes.uint32],
  dtypes.uint32: [dtypes.int64, dtypes.uint64],
  dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2],
  dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16],
  dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16],
  dtypes.float16: [dtypes.float32],
  dtypes.bfloat16: [dtypes.float32],
  dtypes.float32: [dtypes.float64],
}
|
|
|
|
@functools.cache
def _get_recursive_parents(dtype:DType) -> set[DType]:
  """Return dtype together with every dtype reachable from it through promo_lattice."""
  # float64 is the top of the lattice and has no promo_lattice entry
  if dtype == dtypes.float64: return {dtypes.float64}
  reachable = {dtype}
  for parent in promo_lattice[dtype]:
    # union into the fresh set so the cached parent sets are never mutated
    reachable |= _get_recursive_parents(parent)
  return reachable
|
|
@functools.cache
def least_upper_dtype(*ds:DType) -> DType:
  """Return the smallest dtype (per DType.__lt__) that every dtype in ds promotes to.

  If any argument is an ImageDType, the first such argument is returned instead.
  """
  images = [d for d in ds if isinstance(d, ImageDType)]
  if images: return images[0]
  common = set.intersection(*(_get_recursive_parents(d.scalar()) for d in ds))
  return min(common)
|
|
def least_upper_float(dt:DType) -> DType:
  """Return dt if it is already a float dtype, else its least upper bound with dtypes.default_float."""
  if dtypes.is_float(dt): return dt
  return least_upper_dtype(dt, dtypes.default_float)
|
|
|
|
# attribute name -> DType for every dtype (aliases included) on `dtypes`, skipping default_* and void/index
DTYPES_DICT = {attr: dt for attr, dt in dtypes.__dict__.items()
               if isinstance(dt, DType) and not attr.startswith(("default", "void", "index"))}
# DType.name (e.g. "unsigned char") -> attribute name on dtypes; a later alias overwrites an earlier name
INVERSE_DTYPES_DICT = {dt.name: attr for attr, dt in DTYPES_DICT.items()}
INVERSE_DTYPES_DICT.update({"void": "void", "index": "index"})
|
|
|
|
@functools.cache
|
|
def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
|
|
# return if dt1 preserves value of dt0
|
|
# similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
|
|
if dt0 == dt1 or dt0 == dtypes.bool: return True
|
|
match dt1:
|
|
case dtypes.index: return dt0 in dtypes.ints
|
|
case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, *dtypes.fp8s,
|
|
dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
|
|
case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
|
|
case dtypes.half: return dt0 in (*dtypes.fp8s, dtypes.uint8, dtypes.int8)
|
|
case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8)
|
|
case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8)
|
|
case dtypes.uint16: return dt0 in (dtypes.uint8,)
|
|
case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
|
|
case dtypes.int32: return dt0 in (dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
|
|
case dtypes.int16: return dt0 in (dtypes.uint8, dtypes.int8)
|
|
case _: return False
|
|
|
|
def sum_acc_dtype(dt:DType):
  """Default accumulator dtype for summing values of dtype dt.

  Unsigned ints accumulate in at least uint32, other ints and bool in at least int32, and everything else in at least
  the dtype named by the SUM_DTYPE env var (float32 when unset).
  """
  if dtypes.is_unsigned(dt): floor = dtypes.uint
  elif dtypes.is_int(dt) or dt == dtypes.bool: floor = dtypes.int
  else: floor = to_dtype(getenv("SUM_DTYPE", "float32"))
  return least_upper_dtype(dt, floor)
|
|
|
|
def float_to_fp16(x):
  """Round x to the nearest IEEE half-precision value, returned as a python float.

  Values too large for half become +/-inf (struct raises OverflowError for them).
  """
  try:
    packed = struct.pack('e', float(x))
  except OverflowError:
    return math.copysign(math.inf, x)
  return struct.unpack('e', packed)[0]
|
|
|
|
def float_to_bf16(x):
  """Round x to the nearest bfloat16 value (round half to even), returned as a python float.

  Non-finite inputs are returned unchanged. x is first rounded to float32 by struct; finite values too large for
  float32 are returned as +/-inf instead of letting struct raise OverflowError (bfloat16 has float32's exponent range,
  so they overflow bfloat16 too).
  """
  if not math.isfinite(x): return x
  try: u = struct.unpack('I', struct.pack('f', x))[0]
  except OverflowError: return math.copysign(math.inf, x)
  # round to nearest even on the upper 16 bits: add 0x7FFF plus the lowest kept bit, then clear the low half
  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
  return struct.unpack('f', struct.pack('I', u))[0]
|
|
|
|
# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
def float_to_fp8(x: float, dtype: DType) -> int:
  """Encode x as an fp8 byte (sign in bit 7) of the given fp8 dtype, rounding to nearest even.

  Finite values above the dtype's overflow threshold saturate to the largest finite encoding (MAXNORM) instead of
  becoming inf. For e4m3, inf and NaN map to its NaN encodings. For e5m2, +/-inf map to 0x7c/0xfc.
  """
  assert dtype in dtypes.fp8s, "Only for fp8s"
  # e4m3 don't support inf, return 0x7f(+NaN) and 0xff(-NaN) to match jax
  # NaN is unordered, can't compare with zero, use math.copysign to get sign
  if dtype == dtypes.fp8e4m3 and not math.isfinite(x): return 0x7f if math.copysign(1, x) > 0 else 0xff
  if dtype == dtypes.fp8e5m2 and math.isinf(x): return 0x7c if math.copysign(1, x) > 0 else 0xfc
  # EXP_BIAS: fp8 exponent bias; SIGNIFICAND_BITS: mantissa bits + implicit leading bit
  # MINDENORM_O2 / OVERFLOW_THRESHOLD / MINNORM: float64 bit patterns compared against |x| below
  # (half the smallest subnormal, the saturation threshold, the smallest normal); MAXNORM: largest finite encoding
  # INF_VALUE is not read in this function
  config = {
    dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000,
                     "OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},
    dtypes.fp8e5m2: {"EXP_BIAS": 15, "SIGNIFICAND_BITS": 3, "MANTISSA_MASK": 0x3, "MINDENORM_O2": 0x3EE0000000000000,
                     "OVERFLOW_THRESHOLD": 0x40EE000000000000 - 1, "MAXNORM": 0x7B, "MINNORM": 0x3F10000000000000, "INF_VALUE": 0x7E}
  }[dtype]
  # raw IEEE float64 bits of x
  xbits, = struct.unpack('Q', struct.pack('d', x))
  # half of one fp8 ulp, expressed in units of the float64 mantissa
  FP8_DP_HALF_ULP = 1 << (53 - config["SIGNIFICAND_BITS"] - 1)
  sign = ((xbits >> 63) & 1) << 7
  # unbias the float64 exponent, then rebias for the fp8 format
  exp = (((xbits >> 52) & 0x7FF) - 1023 + config["EXP_BIAS"])
  # top mantissa bits of the double that fit in the fp8 mantissa
  mantissa = (xbits >> (53 - config["SIGNIFICAND_BITS"])) & config["MANTISSA_MASK"]
  absx = xbits & 0x7FFFFFFFFFFFFFFF

  # at or below half the smallest subnormal: rounds to zero
  if absx <= config["MINDENORM_O2"]: res = 0
  # NaN (bits above +inf); for e4m3 this is unreachable because non-finite x returned early
  elif absx > 0x7FF0000000000000: res = 0x7F if dtype == dtypes.fp8e4m3 else 0x7E | mantissa
  # too large: saturate to the largest finite value
  elif absx > config["OVERFLOW_THRESHOLD"]: res = config["MAXNORM"]
  elif absx >= config["MINNORM"]:
    # normal range: pack exponent and mantissa, then round half to even on the dropped bits (a carry may bump the exponent)
    res = ((exp << (config["SIGNIFICAND_BITS"] - 1)) | mantissa)
    round_bits = xbits & ((FP8_DP_HALF_ULP << 1) - 1)
    if (round_bits > FP8_DP_HALF_ULP) or (round_bits == FP8_DP_HALF_ULP and (mantissa & 1)): res = res + 1
  else:
    # subnormal range: restore the implicit leading bit and shift it down by (1 - exp), then round half to even
    shift = 1 - exp
    mantissa |= 1 << (config["SIGNIFICAND_BITS"] - 1)
    res = (mantissa >> shift)
    round_bits = (xbits | (1 << (53 - 1))) & ((FP8_DP_HALF_ULP << (shift + 1)) - 1)
    if (round_bits > (FP8_DP_HALF_ULP << shift)) or (round_bits == (FP8_DP_HALF_ULP << shift) and (res & 1)):
      res = res + 1

  res |= sign
  return int(res)
|
|
|
|
def fp8_to_float(x: int, dtype: DType) -> float:
  """Decode the fp8 byte x of the given fp8 dtype into a python float.

  The byte is widened to the 16-bit pattern of an IEEE half and unpacked with struct. NaN inputs decode to a positive NaN.
  """
  assert dtype in dtypes.fp8s, "Only for fp8s"
  # the fp8 byte becomes the top byte of a 16-bit half pattern
  ur = x << 8

  # e5m2 shares half's 5-bit exponent, so the shifted byte is already a half; only canonicalize NaNs
  if dtype == dtypes.fp8e5m2 and (ur & 0x7FFF) > 0x7C00: ur = 0x7FFF
  elif dtype == dtypes.fp8e4m3:
    sign = ur & 0x8000
    # move the 4 exponent bits into the half's exponent field and rebias from 7 to 15 (8 << 10 == 0x2000)
    exponent = ((ur & 0x7800) >> 1) + 0x2000
    # the 3 mantissa bits become the top 3 bits of the half's 10-bit mantissa
    mantissa = (ur & 0x0700) >> 1
    absx = x & 0x7F
    # e4m3 has no inf: magnitude 0x7F is NaN
    if absx == 0x7F: ur = 0x7FFF
    elif exponent == 0x2000:
      # fp8 exponent field was 0: zero or subnormal
      if mantissa != 0:
        # normalize the subnormal: shift until bit 10 (the half's implicit bit) is set, lowering the exponent each step
        mantissa <<= 1
        while (mantissa & 0x0400) == 0:
          mantissa <<= 1
          exponent -= 0x0400
        mantissa &= 0x03FF
      else:
        exponent = 0
      ur = (sign | exponent) | mantissa
    else:
      ur = (sign | exponent) | mantissa

  # pack and unpack with the same explicit byte order; native-order 'e' paired with '<H' swaps bytes on big-endian hosts
  half_val = struct.unpack('<e', struct.pack('<H', ur))[0]
  return float(half_val)
|
|
|
|
# dtype -> python function that rounds/wraps a python scalar the way storing it in that dtype would
truncate: dict[DType, Callable] = {
  dtypes.bool: bool,
  dtypes.float16: float_to_fp16,
  dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
}
# fp8: encode then decode; the `dtype=fp8` default binds each loop value at lambda creation time
truncate.update({fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s})
# remaining types: store into the matching fixed-width ctypes type (ctypes does no overflow checking) and read it back
truncate.update({getattr(dtypes, ctype_name): (lambda x, c=getattr(ctypes, f'c_{ctype_name}'): c(x).value)
                 for ctype_name in ('float', 'double', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64')})
|
|
|
|
# numpy and torch dtype interop

def _to_np_dtype(dtype:DType) -> type|None:
  """Return the numpy scalar type for dtype, or None when dtype has no struct format character."""
  import numpy as np
  # numpy has no bfloat16/fp8 types; they are represented as float32 on the numpy side
  if dtype in (dtypes.bfloat16, *dtypes.fp8s): return np.float32
  if dtype.fmt is None: return None
  return np.dtype(dtype.fmt).type
|
|
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  """Return the DType registered in dtypes.fields() under numpy's name for npdtype (KeyError if there is none)."""
  import numpy as np
  np_name = np.dtype(npdtype).name
  return dtypes.fields()[np_name]
|
|
|
|
@functools.cache
def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821
  """Return the torch dtype matching dtype, or None when there is no equivalent.

  fp8 dtypes are returned as torch.uint8; other dtypes go through their numpy equivalent.
  """
  import numpy as np, torch
  if dtype == dtypes.uint64: return torch.uint64
  if dtype == dtypes.bfloat16: return torch.bfloat16
  if dtype in dtypes.fp8s: return torch.uint8
  # no numpy type (e.g. void, index, vectors) means no torch type; without this check
  # np.array([], dtype=None) would fall back to float64 and report torch.float64
  if (np_type := _to_np_dtype(dtype)) is None: return None
  # NOTE: torch doesn't expose this mapping with a stable API
  try: return torch.from_numpy(np.array([], dtype=np_type)).dtype
  except TypeError: return None
|
|
@functools.cache
def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  """Inverse of _to_torch_dtype over DTYPES_DICT; when several dtypes map to the same torch dtype the last one wins.

  Raises KeyError for torch dtypes with no matching DType.
  """
  reverse: dict = {}
  for dt in DTYPES_DICT.values():
    if (torch_dt := _to_torch_dtype(dt)) is not None: reverse[torch_dt] = dt
  return reverse[torchdtype]
|