truncate_fp16 -> float_to_fp16 (#12234)

match float_to_bf16 and float_to_fp8
This commit is contained in:
chenyu
2025-09-18 09:48:27 -04:00
committed by GitHub
parent 54c15d74a4
commit 7487c13b61
3 changed files with 13 additions and 14 deletions

View File

@@ -519,7 +519,6 @@ jobs:
llvm: "true"
- name: Test CPU=1 RANGEIFY=1
# TODO: add more passing tests here
# test_threefry_doesnt_use_long is because there's a contig after the long now
# test_embedding issue with jit
# test_load_state_dict_sharded_model_dict_same_axis issue with multi
# test_instancenorm_3d is very slow

View File

@@ -1,6 +1,6 @@
import unittest, math, operator, subprocess, struct
from tinygrad.tensor import Tensor, dtypes, Device
from tinygrad.dtype import DType, DTYPES_DICT, truncate, truncate_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
from tinygrad.dtype import DType, DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, CI, DEBUG
from hypothesis import given, settings, strategies as strat
@@ -106,16 +106,16 @@ class TestHelpers(unittest.TestCase):
self.assertEqual(dt.min, dt.vec(4).min)
self.assertEqual(dt.max, dt.vec(4).max)
def test_truncate_fp16(self):
  # (input, expected) pairs, checked in order: exactly representable values, rounding down just
  # below the overflow threshold, overflow to +/-inf at 65520, and a tiny value flushing to zero
  cases = (
    (1, 1),
    (65504, 65504),
    (65519.999, 65504),
    (65520, math.inf),
    (1e-8, 0.0),
    (-65504, -65504),
    (-65519.999, -65504),
    (-65520, -math.inf),
  )
  for value, want in cases:
    self.assertEqual(truncate_fp16(value), want)
  # NaN never compares equal to itself, so it needs an explicit isnan check
  got = truncate_fp16(math.nan)
  self.assertTrue(math.isnan(got))
def test_float_to_fp16(self):
  # (input, expected) pairs, checked in order: exactly representable values, rounding down just
  # below the overflow threshold, overflow to +/-inf at 65520, and a tiny value flushing to zero
  cases = (
    (1, 1),
    (65504, 65504),
    (65519.999, 65504),
    (65520, math.inf),
    (1e-8, 0.0),
    (-65504, -65504),
    (-65519.999, -65504),
    (-65520, -math.inf),
  )
  for value, want in cases:
    self.assertEqual(float_to_fp16(value), want)
  # NaN never compares equal to itself, so it needs an explicit isnan check
  got = float_to_fp16(math.nan)
  self.assertTrue(math.isnan(got))
def test_float_to_bf16(self):
# TODO: fuzz this better

View File

@@ -233,7 +233,7 @@ def sum_acc_dtype(dt:DType):
if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
def truncate_fp16(x):
def float_to_fp16(x):
  """Convert `x` to the value it takes when stored as an IEEE 754 half-precision float.

  The value is round-tripped through the `struct` 'e' (binary16) format and returned as a
  Python float. NaN is preserved. A magnitude too large for half precision yields infinity
  with the sign of `x`, including ints too large to be converted to a Python float at all.
  """
  try:
    return struct.unpack('e', struct.pack('e', float(x)))[0]
  except OverflowError:
    # raised by struct.pack when the value does not fit in fp16, or by float() for a huge int.
    # pick the sign by comparison: math.copysign would convert x to float and overflow again.
    return math.inf if x > 0 else -math.inf
@@ -310,7 +310,7 @@ def fp8_to_float(x: int, dtype: DType) -> float:
return float(float32_val)
truncate: dict[DType, Callable] = {dtypes.bool: bool,
dtypes.float16: truncate_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
dtypes.float16: float_to_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
**{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,