Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
@@ -414,11 +414,11 @@ class TestDtypeUsage(unittest.TestCase):
       t = Tensor([[1, 2], [3, 4]], dtype=d)
       (t*t).max().item()

-@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) or Device.DEFAULT == "PYTHON", f"no bfloat16 on {Device.DEFAULT}")
 class TestOpsBFloat16(unittest.TestCase):
   def test_cast(self):
     # TODO: helper_test_op breaks in unrelated part
-    # TODO: wrong output with GPU=1 / PYTHON=1 on mac
+    # TODO: wrong output with GPU=1 on mac
     data = [60000.0, 70000.0, 80000.0]
     np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())

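The expected torch values here come from rounding the float32 bit pattern to the nearest bfloat16 rather than truncating it. A standalone sketch (mirroring the float_to_bf16 helper introduced later in this diff) reproduces them:

import struct
def bf16_round(x):  # same rounding as the float_to_bf16 helper added later in this diff
  u = struct.unpack('I', struct.pack('f', x))[0]
  return struct.unpack('f', struct.pack('I', (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000))[0]
print([bf16_round(v) for v in [60000.0, 70000.0, 80000.0]])  # [59904.0, 70144.0, 79872.0]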
@@ -56,6 +56,7 @@ class TestCastConvenienceMethod(unittest.TestCase):
 class TestDtypeTolist(unittest.TestCase):
   def test_bfloat16(self):
+    self.assertEqual(Tensor([-60000, 1.5, 3.1, 60000], device="PYTHON", dtype=dtypes.bfloat16).tolist(), [-59904.0, 1.5, 3.09375, 59904.0])
   def test_fp8(self):
     # 448
     self.assertEqual(Tensor([-30000, 1.5, 3.1, 30000], device="PYTHON", dtype=dtypes.fp8e4m3).tolist(), [-448.0, 1.5, 3.0, 448.0])
     # 57344

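The "# 448" and "# 57344" comments are the largest finite values of fp8e4m3 and fp8e5m2: out-of-range inputs like ±30000 saturate to that maximum, and in the [2, 4) range fp8e4m3 steps by 0.25, so 3.1 lands on 3.0. A quick sanity check of the two magic numbers (a sketch, not part of the diff):

# max finite = (1 + largest mantissa fraction) * 2**largest_exponent
print((1 + 0.75) * 2**8)   # 448.0   fp8e4m3: 3 mantissa bits, but the all-ones encoding is reserved for NaN
print((1 + 0.75) * 2**15)  # 57344.0 fp8e5m2: 2 mantissa bits, IEEE-style inf/NaN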
@@ -1,6 +1,6 @@
 import unittest, math, operator, subprocess
 from tinygrad.tensor import Tensor, dtypes, Device
-from tinygrad.dtype import DType, DTYPES_DICT, truncate, truncate_fp16, truncate_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
+from tinygrad.dtype import DType, DTYPES_DICT, truncate, truncate_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
 from tinygrad.device import is_dtype_supported
 from tinygrad.helpers import getenv, CI, DEBUG
 from hypothesis import given, settings, strategies as strat

@@ -108,17 +108,12 @@ class TestHelpers(unittest.TestCase):
     self.assertEqual(truncate_fp16(-65520), -math.inf)
     self.assertTrue(math.isnan(truncate_fp16(math.nan)))

-  def test_truncate_bf16(self):
-    self.assertEqual(truncate_bf16(1), 1)
-    # TODO: rounding, torch bfloat 1.1 gives 1.1015625 instead of 1.09375
-    self.assertAlmostEqual(truncate_bf16(1.1), 1.09375, places=7)
-    for a in [1234, 23456, -777.777]:
-      self.assertEqual(truncate_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
+  def test_float_to_bf16(self):
+    # TODO: fuzz this better
     max_bf16 = torch.finfo(torch.bfloat16).max
-    self.assertEqual(truncate_bf16(max_bf16), max_bf16)
-    self.assertEqual(truncate_bf16(min_bf16:=-max_bf16), min_bf16)
-    self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf)
-    self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf)
+    for a in [1, 1.1, 1234, 23456, -777.777, max_bf16, max_bf16 * 1.00001, -max_bf16, -max_bf16 * 1.00001, math.inf, -math.inf]:
+      self.assertEqual(float_to_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
+    self.assertTrue(math.isnan(float_to_bf16(math.nan)))

   @given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
   def test_truncate_fp8e4m3(self, x):

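The old "TODO: rounding" comment is resolved by this change: truncating the float32 mantissa maps 1.1 to 1.09375, while round-to-nearest-even (which is what torch's bfloat16 cast does, and what the new float_to_bf16 implements) gives 1.1015625. A small illustration of the difference (a sketch, not from the diff):

import struct
u = struct.unpack('I', struct.pack('f', 1.1))[0]
truncated = struct.unpack('f', struct.pack('I', u & 0xFFFF0000))[0]
rounded = struct.unpack('f', struct.pack('I', (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000))[0]
print(truncated, rounded)  # 1.09375 1.1015625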
@@ -217,12 +217,10 @@ def truncate_fp16(x):
   try: return struct.unpack('e', struct.pack('e', float(x)))[0]
   except OverflowError: return math.copysign(math.inf, x)

-def truncate_bf16(x):
-  max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
-  if abs(x) > max_bf16: return math.copysign(math.inf, x)
-  f32_int = struct.unpack('I', struct.pack('f', x))[0]
-  bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
-  return bf
+def float_to_bf16(x):
+  u = struct.unpack('I', struct.pack('f', x))[0]
+  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
+  return struct.unpack('f', struct.pack('I', u))[0]

 # fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
 def float_to_fp8(x: float, dtype: DType) -> int:

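The middle line of float_to_bf16 is the usual round-to-nearest-even trick: add 0x7FFF plus the lowest kept bit ((u >> 16) & 1), then clear the low 16 bits. Unlike the old truncate_bf16, which clamped anything above max_bf16 straight to infinity, this rounds the way torch does: a value only slightly above the largest finite bfloat16 rounds back down to it, while a value past the rounding midpoint carries into the exponent and becomes inf. A small demonstration (a sketch repeating the definition above so it runs standalone):

import struct
def float_to_bf16(x):
  u = struct.unpack('I', struct.pack('f', x))[0]
  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
  return struct.unpack('f', struct.pack('I', u))[0]
max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
print(float_to_bf16(max_bf16 * 1.00001) == max_bf16)  # True: still closest to the largest finite bfloat16
print(float_to_bf16(3.4e38))                          # inf: past the midpoint, the carry reaches the exponent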
@@ -287,7 +285,7 @@ def fp8_to_float(x: int, dtype: DType) -> float:
   return float(float32_val)

 truncate: dict[DType, Callable] = {dtypes.bool: bool,
-  dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
+  dtypes.float16: truncate_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
   **{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
   dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,

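With this entry swapped, truncate[dtypes.bfloat16] now rounds instead of truncating. A hedged usage sketch, using the same imports as the test file in this diff:

from tinygrad.tensor import dtypes
from tinygrad.dtype import truncate
print(truncate[dtypes.bfloat16](1.1))    # 1.1015625 with rounding; the old truncation gave 1.09375
print(truncate[dtypes.float16](-65520))  # -inf, matching the truncate_fp16 test above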