improved float_to_bf16 (#11848)

round instead of truncate
Author: chenyu
Date: 2025-08-26 11:14:06 -04:00 (committed by GitHub)
Parent: afe14ccbfa
Commit: f28f613f85
4 changed files with 14 additions and 20 deletions
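For context, the change replaces plain mantissa truncation with round-to-nearest-even when converting a float32 to bfloat16, which is what torch does. A minimal standalone sketch of the before/after behaviour (the helper names here are illustrative, not from the repo, and the old helper's overflow-to-infinity check is omitted):

import struct

def bf16_truncate(x: float) -> float:
  # old behaviour: drop the low 16 bits of the float32 encoding (rounds toward zero)
  u = struct.unpack('I', struct.pack('f', x))[0]
  return struct.unpack('f', struct.pack('I', u & 0xFFFF0000))[0]

def bf16_round(x: float) -> float:
  # new behaviour: round to nearest, ties to even, then keep the high 16 bits
  u = struct.unpack('I', struct.pack('f', x))[0]
  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
  return struct.unpack('f', struct.pack('I', u))[0]

print(bf16_truncate(1.1))  # 1.09375
print(bf16_round(1.1))     # 1.1015625, matches torch.tensor([1.1], dtype=torch.bfloat16).item()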


@@ -414,11 +414,11 @@ class TestDtypeUsage(unittest.TestCase):
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) or Device.DEFAULT == "PYTHON", f"no bfloat16 on {Device.DEFAULT}")
class TestOpsBFloat16(unittest.TestCase):
def test_cast(self):
# TODO: helper_test_op breaks in unrelated part
# TODO: wrong output with GPU=1 / PYTHON=1 on mac
# TODO: wrong output with GPU=1 on mac
data = [60000.0, 70000.0, 80000.0]
np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
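Of the three test values, only 70000.0 distinguishes the two schemes: in the 65536..131072 range a bfloat16 step is 512, so truncation gives 69632 while torch's rounding gives 70144, which is presumably why the PYTHON=1 part of the TODO above could be dropped. A quick check in plain Python (no torch needed), reusing the rounding expression from the diff:

import struct

for x in [60000.0, 70000.0, 80000.0]:
  u = struct.unpack('I', struct.pack('f', x))[0]
  trunc = struct.unpack('f', struct.pack('I', u & 0xFFFF0000))[0]
  rnd = struct.unpack('f', struct.pack('I', (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000))[0]
  print(x, trunc, rnd)
# 60000.0 59904.0 59904.0
# 70000.0 69632.0 70144.0   <- the only value here where truncation and rounding differ
# 80000.0 79872.0 79872.0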


@@ -56,6 +56,7 @@ class TestCastConvenienceMethod(unittest.TestCase):
class TestDtypeTolist(unittest.TestCase):
def test_bfloat16(self):
self.assertEqual(Tensor([-60000, 1.5, 3.1, 60000], device="PYTHON", dtype=dtypes.bfloat16).tolist(), [-59904.0, 1.5, 3.09375, 59904.0])
def test_fp8(self):
# 448
self.assertEqual(Tensor([-30000, 1.5, 3.1, 30000], device="PYTHON", dtype=dtypes.fp8e4m3).tolist(), [-448.0, 1.5, 3.0, 448.0])
# 57344


@@ -1,6 +1,6 @@
import unittest, math, operator, subprocess
from tinygrad.tensor import Tensor, dtypes, Device
from tinygrad.dtype import DType, DTYPES_DICT, truncate, truncate_fp16, truncate_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
from tinygrad.dtype import DType, DTYPES_DICT, truncate, truncate_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, CI, DEBUG
from hypothesis import given, settings, strategies as strat
@@ -108,17 +108,12 @@ class TestHelpers(unittest.TestCase):
self.assertEqual(truncate_fp16(-65520), -math.inf)
self.assertTrue(math.isnan(truncate_fp16(math.nan)))
def test_truncate_bf16(self):
self.assertEqual(truncate_bf16(1), 1)
# TODO: rounding, torch bfloat 1.1 gives 1.1015625 instead of 1.09375
self.assertAlmostEqual(truncate_bf16(1.1), 1.09375, places=7)
for a in [1234, 23456, -777.777]:
self.assertEqual(truncate_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
def test_float_to_bf16(self):
# TODO: fuzz this better
max_bf16 = torch.finfo(torch.bfloat16).max
self.assertEqual(truncate_bf16(max_bf16), max_bf16)
self.assertEqual(truncate_bf16(min_bf16:=-max_bf16), min_bf16)
self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf)
self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf)
for a in [1, 1.1, 1234, 23456, -777.777, max_bf16, max_bf16 * 1.00001, -max_bf16, -max_bf16 * 1.00001, math.inf, -math.inf]:
self.assertEqual(float_to_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
self.assertTrue(math.isnan(float_to_bf16(math.nan)))
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e4m3(self, x):


@@ -217,12 +217,10 @@ def truncate_fp16(x):
try: return struct.unpack('e', struct.pack('e', float(x)))[0]
except OverflowError: return math.copysign(math.inf, x)
def truncate_bf16(x):
max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
if abs(x) > max_bf16: return math.copysign(math.inf, x)
f32_int = struct.unpack('I', struct.pack('f', x))[0]
bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
return bf
def float_to_bf16(x):
u = struct.unpack('I', struct.pack('f', x))[0]
u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
return struct.unpack('f', struct.pack('I', u))[0]
# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
def float_to_fp8(x: float, dtype: DType) -> int:
@@ -287,7 +285,7 @@ def fp8_to_float(x: int, dtype: DType) -> float:
return float(float32_val)
truncate: dict[DType, Callable] = {dtypes.bool: bool,
dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
dtypes.float16: truncate_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
**{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
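The bit expression in the new float_to_bf16 implements round-to-nearest-even on the raw float32 bits: adding 0x7FFF carries into the kept bits whenever the discarded low 16 bits are strictly above the halfway point, and the extra ((u >> 16) & 1) pushes exact ties up only when the kept mantissa bit pattern is odd, so ties always land on an even mantissa. A small self-contained check of the tie-breaking (the halfway inputs below are constructed by hand for illustration):

import struct

def bf16_round(x: float) -> float:
  # same expression as the new float_to_bf16 above
  u = struct.unpack('I', struct.pack('f', x))[0]
  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
  return struct.unpack('f', struct.pack('I', u))[0]

# inputs whose discarded low 16 bits are exactly 0x8000, i.e. exact halfway cases
halfway_even = struct.unpack('f', struct.pack('I', 0x3F808000))[0]  # 1.00390625, kept mantissa 0x00 (even)
halfway_odd  = struct.unpack('f', struct.pack('I', 0x3F818000))[0]  # 1.01171875, kept mantissa 0x01 (odd)
print(bf16_round(halfway_even))  # 1.0       -> stays on the even mantissa 0x00
print(bf16_round(halfway_odd))   # 1.015625  -> rounds up to the even mantissa 0x02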