remove ml_dtypes (#12169)

This commit is contained in:
chenyu
2025-09-14 14:20:05 -04:00
committed by GitHub
parent 02054b53fe
commit 98ecab7563
3 changed files with 14 additions and 14 deletions

View File

@@ -10,7 +10,6 @@ from tinygrad import Device, Tensor, dtypes
from hypothesis import assume, given, settings, strategies as strat
from test.helpers import rand_for_dtype
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
import ml_dtypes
import pytest
pytestmark = pytest.mark.filterwarnings("ignore")
@@ -129,11 +128,10 @@ class TestDType(unittest.TestCase):
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
def test_finfo(self):
if self.DTYPE not in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: return
info = ml_dtypes.finfo(ml_dtypes.bfloat16 if self.DTYPE is dtypes.bfloat16 else _to_np_dtype(self.DTYPE))
assert info.bits == self.DTYPE.itemsize*8
assert info.nexp == dtypes.finfo(self.DTYPE)[0]
assert info.nmant == dtypes.finfo(self.DTYPE)[1]
if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return
info = np.finfo(_to_np_dtype(self.DTYPE))
self.assertEqual(info.bits, self.DTYPE.itemsize*8)
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE))
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
@@ -151,7 +149,8 @@ class TestFp8s(unittest.TestCase):
class TestFp8sConversions(unittest.TestCase):
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
def test_float_to_fp8e4m3(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0])
def test_float_to_fp8e4m3(self, x):
np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.float8_e4m3fn).view(torch.uint8).item())
def test_float_to_fp8e4m3_extreme_values(self):
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
@@ -164,7 +163,8 @@ class TestFp8sConversions(unittest.TestCase):
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX))
def test_float_to_fp8e5m2(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0])
def test_float_to_fp8e5m2(self, x):
np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.float8_e5m2).view(torch.uint8).item())
def test_float_to_fp8e5m2_extreme_values(self):
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
@@ -177,10 +177,12 @@ class TestFp8sConversions(unittest.TestCase):
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
@given(strat.integers(min_value=0, max_value=255))
def test_fp8e4m3_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item())
def test_fp8e4m3_to_float(self, x):
np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e4m3fn).float().item())
@given(strat.integers(min_value=0, max_value=255))
def test_fp8e5m2_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item())
def test_fp8e5m2_to_float(self, x):
np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2).float().item())
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16(unittest.TestCase):