simplify dtype (#3137)

This commit is contained in:
George Hotz
2024-01-15 16:27:43 -08:00
committed by GitHub
parent e4528543fa
commit a5d634a541

View File

@@ -1,5 +1,5 @@
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Any
from dataclasses import dataclass, field
from typing import Final, Optional, ClassVar, Set, Tuple, Dict
from dataclasses import dataclass
import numpy as np # TODO: remove numpy
import functools
@@ -9,7 +9,7 @@ class DType:
itemsize: int
name: str
fmt: Optional[str]
sz: int = 1
sz: int
def __repr__(self): return f"dtypes.{'_'*(c:=self.sz!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.sz)*c}"
def vec(self, sz:int):
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
@@ -22,10 +22,8 @@ class DType:
# dependent typing?
@dataclass(frozen=True, repr=False)
class ImageDType(DType):
shape: Tuple[int, ...] = (0,) # arbitrary arg for the dtype, used in image for the shape
base: Any = field(default=None, hash=False)
def __post_init__(self):
if not isinstance(self.base, DType): raise ValueError("base is not a valid dtype")
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
def scalar(self): return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
@@ -36,7 +34,7 @@ class PtrDType(DType):
def __repr__(self): return f"ptr.{super().__repr__()}"
def __hash__(self): return super().__hash__()
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.sz==dt.sz
def __ne__(self, dt): return self.priority!=dt.priority or self.itemsize!=dt.itemsize or self.name!=dt.name or self.sz!=dt.sz
def __ne__(self, dt): return not (self == dt)
class dtypes:
@staticmethod
@@ -51,20 +49,20 @@ class dtypes:
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", '?')
int8: Final[DType] = DType(1, 1, "char", 'b')
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B')
int16: Final[DType] = DType(3, 2, "short", 'h')
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H')
int32: Final[DType] = DType(5, 4, "int", 'i')
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I')
int64: Final[DType] = DType(7, 8, "long", 'l')
uint64: Final[DType] = DType(8, 8, "unsigned long", 'L')
float16: Final[DType] = DType(9, 2, "half", 'e')
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
int16: Final[DType] = DType(3, 2, "short", 'h', 1)
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
int32: Final[DType] = DType(5, 4, "int", 'i', 1)
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
int64: Final[DType] = DType(7, 8, "long", 'l', 1)
uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
float16: Final[DType] = DType(9, 2, "half", 'e', 1)
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
float32: Final[DType] = DType(11, 4, "float", 'f')
float64: Final[DType] = DType(12, 8, "double", 'd')
bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
float32: Final[DType] = DType(11, 4, "float", 'f', 1)
float64: Final[DType] = DType(12, 8, "double", 'd', 1)
# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
@@ -73,9 +71,9 @@ class dtypes:
# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', shape=shp, base=dtypes.float32)
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', shape=shp, base=dtypes.float32)
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32