mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
minor cleanups in dtype.py (#2978)
* minor cleanups in dtype.py * all not
This commit is contained in:
@@ -12,7 +12,7 @@ class DType(NamedTuple):
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
|
||||
def vec(self, sz:int):
|
||||
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{str(sz)}", None, sz)
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{sz}", None, sz)
|
||||
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
|
||||
|
||||
# dependent typing?
|
||||
@@ -79,9 +79,8 @@ class dtypes:
|
||||
|
||||
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
|
||||
# we don't support weak type and complex type
|
||||
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
|
||||
dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float16, dtypes.bfloat16],
|
||||
dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
|
||||
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
|
||||
dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
|
||||
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
|
||||
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
|
||||
|
||||
@@ -94,6 +93,5 @@ def least_upper_dtype(*ds:DType) -> DType:
|
||||
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if (
|
||||
not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)}
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith('__') or k.startswith('default') or v.__class__ is staticmethod)}
|
||||
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
|
||||
|
||||
Reference in New Issue
Block a user