mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 13:45:50 -05:00
use shl everywhere (#6744)
* use shl everywhere * fix parens * late patterns * works as an extra pass * ptx
This commit is contained in:
@@ -44,13 +44,13 @@ class PtrDType(DType):
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def is_float(x: DType) -> bool: return x.scalar() in {dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64}
|
||||
def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats
|
||||
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
|
||||
@functools.lru_cache(None)
|
||||
def is_int(x: DType) -> bool: return x.scalar() in {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.pyint} or dtypes.is_unsigned(x)
|
||||
def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def is_unsigned(x: DType) -> bool: return x.scalar() in {dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64}
|
||||
def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
|
||||
@staticmethod
|
||||
def from_py(x) -> DType:
|
||||
if x.__class__ is float: return dtypes.default_float
|
||||
@@ -114,6 +114,11 @@ class dtypes:
|
||||
default_float: ClassVar[DType] = float32
|
||||
default_int: ClassVar[DType] = int32
|
||||
|
||||
floats = (float16, bfloat16, float32, float64)
|
||||
uints = (uint8, uint16, uint32, uint64)
|
||||
sints = (int8, int16, int32, int64, pyint)
|
||||
ints = uints + sints
|
||||
|
||||
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
|
||||
dtypes.default_float = getattr(dtypes, env_default_float.lower())
|
||||
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
|
||||
@@ -137,7 +142,8 @@ def least_upper_dtype(*ds:DType) -> DType:
|
||||
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void')) or v.__class__ is staticmethod)}
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
|
||||
or v.__class__ is staticmethod or isinstance(v, tuple))}
|
||||
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
|
||||
INVERSE_DTYPES_DICT['void'] = 'void'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user