use shl everywhere (#6744)

* use shl everywhere

* fix parens

* late patterns

* works as an extra pass

* ptx
This commit is contained in:
George Hotz
2024-09-26 09:59:36 +08:00
committed by GitHub
parent 88160e59b2
commit b199b699ed
5 changed files with 38 additions and 25 deletions

View File

@@ -44,13 +44,13 @@ class PtrDType(DType):
class dtypes:
@staticmethod
@functools.lru_cache(None)
def is_float(x: DType) -> bool: return x.scalar() in {dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64}
def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
@functools.lru_cache(None)
def is_int(x: DType) -> bool: return x.scalar() in {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.pyint} or dtypes.is_unsigned(x)
def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
@staticmethod
@functools.lru_cache(None)
def is_unsigned(x: DType) -> bool: return x.scalar() in {dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64}
def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
@staticmethod
def from_py(x) -> DType:
if x.__class__ is float: return dtypes.default_float
@@ -114,6 +114,11 @@ class dtypes:
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
floats = (float16, bfloat16, float32, float64)
uints = (uint8, uint16, uint32, uint64)
sints = (int8, int16, int32, int64, pyint)
ints = uints + sints
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
dtypes.default_float = getattr(dtypes, env_default_float.lower())
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
@@ -137,7 +142,8 @@ def least_upper_dtype(*ds:DType) -> DType:
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void')) or v.__class__ is staticmethod)}
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
or v.__class__ is staticmethod or isinstance(v, tuple))}
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
INVERSE_DTYPES_DICT['void'] = 'void'