mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)

truncate_fp16 -> float_to_fp16 (#12234)
match float_to_bf16 and float_to_fp8
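
For context, the helper being renamed converts a Python float to the nearest representable float16 value; only its name changes in this commit, and its body (visible in the tinygrad/dtype.py hunk below) stays the same. A minimal standalone sketch of that behavior, using the boundary values the updated tests exercise:

import struct, math

def float_to_fp16(x):
  # round-trip through struct's IEEE 754 binary16 format code 'e':
  # packing rounds x to the nearest representable half-precision value
  # and raises OverflowError once rounding would leave the finite range
  try: return struct.unpack('e', struct.pack('e', float(x)))[0]
  except OverflowError: return math.copysign(math.inf, x)

assert float_to_fp16(65504) == 65504      # largest finite fp16 value
assert float_to_fp16(65519.999) == 65504  # below the midpoint, rounds back down
assert float_to_fp16(65520) == math.inf   # at the midpoint, rounds out of range
assert float_to_fp16(1e-8) == 0.0         # below half the smallest subnormal
assert math.isnan(float_to_fp16(math.nan))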

.github/workflows/test.yml (vendored, 1 line changed)

@@ -519,7 +519,6 @@ jobs:
           llvm: "true"
       - name: Test CPU=1 RANGEIFY=1
         # TODO: add more passing tests here
-        # test_threefry_doesnt_use_long is because there's a contig after the long now
         # test_embedding issue with jit
         # test_load_state_dict_sharded_model_dict_same_axis issue with multi
         # test_instancenorm_3d is very slow

test/test_dtype.py

@@ -1,6 +1,6 @@
 import unittest, math, operator, subprocess, struct
 from tinygrad.tensor import Tensor, dtypes, Device
-from tinygrad.dtype import DType, DTYPES_DICT, truncate, truncate_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
+from tinygrad.dtype import DType, DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
 from tinygrad.device import is_dtype_supported
 from tinygrad.helpers import getenv, CI, DEBUG
 from hypothesis import given, settings, strategies as strat

@@ -106,16 +106,16 @@ class TestHelpers(unittest.TestCase):
     self.assertEqual(dt.min, dt.vec(4).min)
     self.assertEqual(dt.max, dt.vec(4).max)
 
-  def test_truncate_fp16(self):
-    self.assertEqual(truncate_fp16(1), 1)
-    self.assertEqual(truncate_fp16(65504), 65504)
-    self.assertEqual(truncate_fp16(65519.999), 65504)
-    self.assertEqual(truncate_fp16(65520), math.inf)
-    self.assertEqual(truncate_fp16(1e-8), 0.0)
-    self.assertEqual(truncate_fp16(-65504), -65504)
-    self.assertEqual(truncate_fp16(-65519.999), -65504)
-    self.assertEqual(truncate_fp16(-65520), -math.inf)
-    self.assertTrue(math.isnan(truncate_fp16(math.nan)))
+  def test_float_to_fp16(self):
+    self.assertEqual(float_to_fp16(1), 1)
+    self.assertEqual(float_to_fp16(65504), 65504)
+    self.assertEqual(float_to_fp16(65519.999), 65504)
+    self.assertEqual(float_to_fp16(65520), math.inf)
+    self.assertEqual(float_to_fp16(1e-8), 0.0)
+    self.assertEqual(float_to_fp16(-65504), -65504)
+    self.assertEqual(float_to_fp16(-65519.999), -65504)
+    self.assertEqual(float_to_fp16(-65520), -math.inf)
+    self.assertTrue(math.isnan(float_to_fp16(math.nan)))
 
   def test_float_to_bf16(self):
     # TODO: fuzz this better
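
Why these boundary values: 65504 is the largest finite float16, and the next step the format would otherwise take is 65536. 65519.999 sits below the midpoint 65520, so it rounds back to 65504; 65520 is exactly the midpoint, and round-to-nearest-even resolves the tie to the out-of-range 65536, hence inf. 1e-8 is under half of fp16's smallest subnormal (2**-24, about 5.96e-8), so it rounds to 0.0.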

tinygrad/dtype.py

@@ -233,7 +233,7 @@ def sum_acc_dtype(dt:DType):
   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
   return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
 
-def truncate_fp16(x):
+def float_to_fp16(x):
   try: return struct.unpack('e', struct.pack('e', float(x)))[0]
   except OverflowError: return math.copysign(math.inf, x)

@@ -310,7 +310,7 @@ def fp8_to_float(x: int, dtype: DType) -> float:
   return float(float32_val)
 
 truncate: dict[DType, Callable] = {dtypes.bool: bool,
-  dtypes.float16: truncate_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
+  dtypes.float16: float_to_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
   **{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
   dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
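
The truncate table maps each DType to a plain-Python callable that coerces a host value to what that dtype can actually hold. A short usage sketch, assuming a tinygrad checkout containing this commit (truncate and dtypes are importable from tinygrad.dtype, as the test imports above show):

from tinygrad.dtype import dtypes, truncate

# floats round per their format; fixed-width ints wrap via ctypes
assert truncate[dtypes.float16](65520) == float('inf')  # dispatches to float_to_fp16
assert truncate[dtypes.uint8](257) == 1                 # ctypes.c_uint8 wraps mod 2**8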