minor cleanups in dtype.py (#2978)

* minor cleanups in dtype.py

* all not
This commit is contained in:
chenyu
2024-01-02 13:42:37 -05:00
committed by GitHub
parent ff5399f053
commit 91ddda244f

View File

@@ -12,7 +12,7 @@ class DType(NamedTuple):
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
def vec(self, sz:int):
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{str(sz)}", None, sz)
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{sz}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
# dependent typing?
@@ -79,9 +79,8 @@ class dtypes:
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float16, dtypes.bfloat16],
dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
@@ -94,6 +93,5 @@ def least_upper_dtype(*ds:DType) -> DType:
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if (
not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)}
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith('__') or k.startswith('default') or v.__class__ is staticmethod)}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}