improved float_to_bf16 (#11848)

Round instead of truncating when converting float to bfloat16.
This commit is contained in:
chenyu
2025-08-26 11:14:06 -04:00
committed by GitHub
parent afe14ccbfa
commit f28f613f85
4 changed files with 14 additions and 20 deletions

View File

@@ -414,11 +414,11 @@ class TestDtypeUsage(unittest.TestCase):
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) or Device.DEFAULT == "PYTHON", f"no bfloat16 on {Device.DEFAULT}")
class TestOpsBFloat16(unittest.TestCase):
def test_cast(self):
# TODO: helper_test_op breaks in unrelated part
# TODO: wrong output with GPU=1 / PYTHON=1 on mac
# TODO: wrong output with GPU=1 on mac
data = [60000.0, 70000.0, 80000.0]
np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())