mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
cleanup truncate_bf16 [pr] (#9725)
Use torch bfloat16 as the ground truth in the test. Also adds a TODO for the rounding discrepancy.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import unittest, operator, subprocess, struct, math
|
||||
import unittest, operator, subprocess, math
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Any, List
|
||||
@@ -442,10 +442,14 @@ class TestHelpers(unittest.TestCase):
|
||||
def test_truncate_bf16(self):
  """truncate_bf16 should round-trip exact bf16 values and saturate to inf past the bf16 range.

  Uses torch's bfloat16 cast as the ground truth for representative values.
  """
  # exact small integer survives truncation unchanged
  self.assertEqual(truncate_bf16(1), 1)
  # 1.1 is not representable in bf16; mantissa truncation yields 1.09375
  self.assertAlmostEqual(truncate_bf16(1.1), 1.09375, places=7)
  # ground truth from torch's bfloat16 cast for a few representative values
  for a in [1234, 23456, -777.777]:
    self.assertEqual(truncate_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
  # TODO: torch bfloat 1.1 gives 1.1015625 instead of 1.09375
  # (torch rounds-to-nearest-even while truncate_bf16 truncates the mantissa)
  max_bf16 = torch.finfo(torch.bfloat16).max
  # the extremes of the bf16 range are themselves representable
  self.assertEqual(truncate_bf16(max_bf16), max_bf16)
  self.assertEqual(truncate_bf16(min_bf16:=-max_bf16), min_bf16)
  # anything beyond the representable range saturates to signed infinity
  self.assertEqual(truncate_bf16(max_bf16 * 1.001), math.inf)
  self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf)
  self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf)
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -188,7 +188,7 @@ def truncate_fp16(x):
|
||||
|
||||
def truncate_bf16(x):
  """Truncate a Python float to the nearest-below bfloat16 value, returned as a float.

  bf16 is the top 16 bits of an IEEE-754 float32 (1 sign, 8 exponent, 7 mantissa bits),
  so truncation is just masking off the low 16 mantissa bits of the float32 encoding.
  Values whose magnitude exceeds the largest finite bf16 saturate to signed infinity.
  NOTE: this truncates toward zero in the mantissa; it does NOT round-to-nearest-even
  like torch's bfloat16 cast (e.g. 1.1 -> 1.09375 here vs 1.1015625 in torch).
  """
  # largest finite bf16: exponent 0xFE, mantissa 0x7F -> float32 bits 0x7f7f0000
  max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
  # out-of-range magnitudes overflow to infinity with the input's sign
  if abs(x) > max_bf16: return math.copysign(math.inf, x)
  # reinterpret the float32 bits and zero the low 16 mantissa bits
  f32_int = struct.unpack('I', struct.pack('f', x))[0]
  return struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
|
||||
|
||||
Reference in New Issue
Block a user