mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add truncate_bf16 (#9078)
Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import unittest, operator, subprocess, math
|
||||
import unittest, operator, subprocess, struct, math
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Any, List
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, to_dtype
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, truncate_bf16, to_dtype
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
@@ -439,6 +439,14 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertEqual(truncate_fp16(65519.999), 65504)
|
||||
self.assertEqual(truncate_fp16(65520), math.inf)
|
||||
|
||||
def test_truncate_bf16(self):
|
||||
self.assertEqual(truncate_bf16(1), 1)
|
||||
self.assertAlmostEqual(truncate_bf16(1.1), 1.09375, places=7)
|
||||
max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
|
||||
self.assertEqual(truncate_bf16(max_bf16), max_bf16)
|
||||
self.assertEqual(truncate_bf16(min_bf16:=-max_bf16), min_bf16)
|
||||
self.assertEqual(truncate_bf16(max_bf16 * 1.001), math.inf)
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
|
||||
@@ -183,9 +183,16 @@ def truncate_fp16(x):
|
||||
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
|
||||
except OverflowError: return math.copysign(math.inf, x)
|
||||
|
||||
def truncate_bf16(x):
|
||||
max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
|
||||
if x > max_bf16 or x < -max_bf16: return math.copysign(math.inf, x)
|
||||
f32_int = struct.unpack('I', struct.pack('f', x))[0]
|
||||
bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
|
||||
return bf
|
||||
|
||||
truncate: dict[DType, Callable] = {dtypes.bool: bool,
|
||||
# TODO: bfloat16
|
||||
dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
||||
dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
|
||||
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
||||
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
|
||||
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
||||
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
|
||||
|
||||
Reference in New Issue
Block a user