mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fast scalar (#7545)
* fast scalar set on dtype
* prevent loop
* lru_cache those
This commit is contained in:
@@ -21,6 +21,9 @@ class DType(metaclass=DTypeMetaClass):
|
||||
name: str
|
||||
fmt: Optional[str]
|
||||
count: int
|
||||
_scalar: Optional[DType]
|
||||
@staticmethod
|
||||
def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
|
||||
def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
|
||||
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
|
||||
@@ -28,12 +31,13 @@ class DType(metaclass=DTypeMetaClass):
|
||||
def base(self): return self
|
||||
@property
|
||||
def vcount(self): return self.count
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def vec(self, sz:int) -> DType:
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||
def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, 1)
|
||||
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
|
||||
def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
|
||||
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class PtrDType(DType):
|
||||
@@ -42,8 +46,11 @@ class PtrDType(DType):
|
||||
v: int
|
||||
@property
|
||||
def base(self): return self._base
|
||||
def scalar(self) -> PtrDType: return self.vec(1)
|
||||
def vec(self, sz:int) -> PtrDType: return type(self)(*tuple(getattr(self, f.name) if f.name != 'v' else sz for f in fields(self)))
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def vec(self, sz:int) -> DType:
|
||||
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
|
||||
if sz == 1: return self # sz=1 is a scalar
|
||||
return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
|
||||
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
|
||||
@property
|
||||
def vcount(self): return self.v
|
||||
@@ -99,21 +106,21 @@ class dtypes:
|
||||
return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
|
||||
@staticmethod
|
||||
def fields() -> Dict[str, DType]: return DTYPES_DICT
|
||||
void: Final[DType] = DType(-1, 0, "void", None, 1)
|
||||
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
|
||||
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
|
||||
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
|
||||
int16: Final[DType] = DType(3, 2, "short", 'h', 1)
|
||||
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
|
||||
int32: Final[DType] = DType(5, 4, "int", 'i', 1)
|
||||
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
|
||||
int64: Final[DType] = DType(7, 8, "long", 'q', 1)
|
||||
uint64: Final[DType] = DType(8, 8, "unsigned long", 'Q', 1)
|
||||
float16: Final[DType] = DType(9, 2, "half", 'e', 1)
|
||||
void: Final[DType] = DType.new(-1, 0, "void", None)
|
||||
bool: Final[DType] = DType.new(0, 1, "bool", '?')
|
||||
int8: Final[DType] = DType.new(1, 1, "char", 'b')
|
||||
uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
|
||||
int16: Final[DType] = DType.new(3, 2, "short", 'h')
|
||||
uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
|
||||
int32: Final[DType] = DType.new(5, 4, "int", 'i')
|
||||
uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
|
||||
int64: Final[DType] = DType.new(7, 8, "long", 'q')
|
||||
uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
|
||||
float16: Final[DType] = DType.new(9, 2, "half", 'e')
|
||||
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
|
||||
bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
|
||||
float32: Final[DType] = DType(11, 4, "float", 'f', 1)
|
||||
float64: Final[DType] = DType(12, 8, "double", 'd', 1)
|
||||
bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None)
|
||||
float32: Final[DType] = DType.new(11, 4, "float", 'f')
|
||||
float64: Final[DType] = DType.new(12, 8, "double", 'd')
|
||||
|
||||
# dtype aliases
|
||||
half = float16; float = float32; double = float64 # noqa: E702
|
||||
@@ -122,9 +129,9 @@ class dtypes:
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, dtypes.float32, False, 1, shp)
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
|
||||
@staticmethod
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, dtypes.float32, False, 1, shp)
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
|
||||
|
||||
default_float: ClassVar[DType] = float32
|
||||
default_int: ClassVar[DType] = int32
|
||||
|
||||
Reference in New Issue
Block a user