mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 13:45:50 -05:00
@@ -156,7 +156,7 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
|
||||
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
|
||||
|
||||
DTypeLike = Union[str, DType]
|
||||
def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
|
||||
def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower())
|
||||
|
||||
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
|
||||
# we don't support weak type and complex type
|
||||
@@ -180,7 +180,7 @@ def sum_acc_dtype(dt:DType):
|
||||
# default acc dtype for sum
|
||||
if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
|
||||
if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
|
||||
return least_upper_dtype(dt, dtypes.float)
|
||||
return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
|
||||
|
||||
def truncate_fp16(x):
|
||||
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
|
||||
|
||||
Reference in New Issue
Block a user