fast scalar (#7545)

* fast scalar set on dtype

* prevent loop

* lru_cache those
This commit is contained in:
George Hotz
2024-11-05 14:08:08 +08:00
committed by GitHub
parent 5682955c7b
commit d87adccb6c

View File

@@ -21,6 +21,9 @@ class DType(metaclass=DTypeMetaClass):
name: str
fmt: Optional[str]
count: int
_scalar: Optional[DType]
@staticmethod
def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@@ -28,12 +31,13 @@ class DType(metaclass=DTypeMetaClass):
def base(self): return self
@property
def vcount(self): return self.count
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def vec(self, sz:int) -> DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, 1)
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
@dataclass(frozen=True, eq=False)
class PtrDType(DType):
@@ -42,8 +46,11 @@ class PtrDType(DType):
v: int
@property
def base(self): return self._base
def scalar(self) -> PtrDType: return self.vec(1)
def vec(self, sz:int) -> PtrDType: return type(self)(*tuple(getattr(self, f.name) if f.name != 'v' else sz for f in fields(self)))
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def vec(self, sz:int) -> DType:
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property
def vcount(self): return self.v
@@ -99,21 +106,21 @@ class dtypes:
return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
void: Final[DType] = DType(-1, 0, "void", None, 1)
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
int16: Final[DType] = DType(3, 2, "short", 'h', 1)
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
int32: Final[DType] = DType(5, 4, "int", 'i', 1)
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
int64: Final[DType] = DType(7, 8, "long", 'q', 1)
uint64: Final[DType] = DType(8, 8, "unsigned long", 'Q', 1)
float16: Final[DType] = DType(9, 2, "half", 'e', 1)
void: Final[DType] = DType.new(-1, 0, "void", None)
bool: Final[DType] = DType.new(0, 1, "bool", '?')
int8: Final[DType] = DType.new(1, 1, "char", 'b')
uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
int16: Final[DType] = DType.new(3, 2, "short", 'h')
uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
int32: Final[DType] = DType.new(5, 4, "int", 'i')
uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
int64: Final[DType] = DType.new(7, 8, "long", 'q')
uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
float16: Final[DType] = DType.new(9, 2, "half", 'e')
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
float32: Final[DType] = DType(11, 4, "float", 'f', 1)
float64: Final[DType] = DType(12, 8, "double", 'd', 1)
bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None)
float32: Final[DType] = DType.new(11, 4, "float", 'f')
float64: Final[DType] = DType.new(12, 8, "double", 'd')
# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
@@ -122,9 +129,9 @@ class dtypes:
# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, dtypes.float32, False, 1, shp)
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, dtypes.float32, False, 1, shp)
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32