mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
simplify dtype (#3137)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict
|
||||
from dataclasses import dataclass
|
||||
import numpy as np # TODO: remove numpy
|
||||
import functools
|
||||
|
||||
@@ -9,7 +9,7 @@ class DType:
|
||||
itemsize: int
|
||||
name: str
|
||||
fmt: Optional[str]
|
||||
sz: int = 1
|
||||
sz: int
|
||||
def __repr__(self): return f"dtypes.{'_'*(c:=self.sz!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.sz)*c}"
|
||||
def vec(self, sz:int):
|
||||
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
|
||||
@@ -22,10 +22,8 @@ class DType:
|
||||
# dependent typing?
|
||||
@dataclass(frozen=True, repr=False)
|
||||
class ImageDType(DType):
|
||||
shape: Tuple[int, ...] = (0,) # arbitrary arg for the dtype, used in image for the shape
|
||||
base: Any = field(default=None, hash=False)
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.base, DType): raise ValueError("base is not a valid dtype")
|
||||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
|
||||
base: DType
|
||||
def scalar(self): return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
@@ -36,7 +34,7 @@ class PtrDType(DType):
|
||||
def __repr__(self): return f"ptr.{super().__repr__()}"
|
||||
def __hash__(self): return super().__hash__()
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.sz==dt.sz
|
||||
def __ne__(self, dt): return self.priority!=dt.priority or self.itemsize!=dt.itemsize or self.name!=dt.name or self.sz!=dt.sz
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
@@ -51,20 +49,20 @@ class dtypes:
|
||||
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
|
||||
@staticmethod
|
||||
def fields() -> Dict[str, DType]: return DTYPES_DICT
|
||||
bool: Final[DType] = DType(0, 1, "bool", '?')
|
||||
int8: Final[DType] = DType(1, 1, "char", 'b')
|
||||
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B')
|
||||
int16: Final[DType] = DType(3, 2, "short", 'h')
|
||||
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H')
|
||||
int32: Final[DType] = DType(5, 4, "int", 'i')
|
||||
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I')
|
||||
int64: Final[DType] = DType(7, 8, "long", 'l')
|
||||
uint64: Final[DType] = DType(8, 8, "unsigned long", 'L')
|
||||
float16: Final[DType] = DType(9, 2, "half", 'e')
|
||||
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
|
||||
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
|
||||
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
|
||||
int16: Final[DType] = DType(3, 2, "short", 'h', 1)
|
||||
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
|
||||
int32: Final[DType] = DType(5, 4, "int", 'i', 1)
|
||||
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
|
||||
int64: Final[DType] = DType(7, 8, "long", 'l', 1)
|
||||
uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
|
||||
float16: Final[DType] = DType(9, 2, "half", 'e', 1)
|
||||
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
|
||||
bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
|
||||
float32: Final[DType] = DType(11, 4, "float", 'f')
|
||||
float64: Final[DType] = DType(12, 8, "double", 'd')
|
||||
bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
|
||||
float32: Final[DType] = DType(11, 4, "float", 'f', 1)
|
||||
float64: Final[DType] = DType(12, 8, "double", 'd', 1)
|
||||
|
||||
# dtype aliases
|
||||
half = float16; float = float32; double = float64 # noqa: E702
|
||||
@@ -73,9 +71,9 @@ class dtypes:
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', shape=shp, base=dtypes.float32)
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
|
||||
@staticmethod
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', shape=shp, base=dtypes.float32)
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
|
||||
|
||||
default_float: ClassVar[DType] = float32
|
||||
default_int: ClassVar[DType] = int32
|
||||
|
||||
Reference in New Issue
Block a user