mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
remove ml_dtypes (#12169)
This commit is contained in:
@@ -10,7 +10,6 @@ from tinygrad import Device, Tensor, dtypes
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
from test.helpers import rand_for_dtype
|
||||
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
|
||||
import ml_dtypes
|
||||
import pytest
|
||||
pytestmark = pytest.mark.filterwarnings("ignore")
|
||||
|
||||
@@ -129,11 +128,10 @@ class TestDType(unittest.TestCase):
|
||||
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
|
||||
|
||||
def test_finfo(self):
|
||||
if self.DTYPE not in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: return
|
||||
info = ml_dtypes.finfo(ml_dtypes.bfloat16 if self.DTYPE is dtypes.bfloat16 else _to_np_dtype(self.DTYPE))
|
||||
assert info.bits == self.DTYPE.itemsize*8
|
||||
assert info.nexp == dtypes.finfo(self.DTYPE)[0]
|
||||
assert info.nmant == dtypes.finfo(self.DTYPE)[1]
|
||||
if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return
|
||||
info = np.finfo(_to_np_dtype(self.DTYPE))
|
||||
self.assertEqual(info.bits, self.DTYPE.itemsize*8)
|
||||
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE))
|
||||
|
||||
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
|
||||
target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
|
||||
@@ -151,7 +149,8 @@ class TestFp8s(unittest.TestCase):
|
||||
|
||||
class TestFp8sConversions(unittest.TestCase):
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
|
||||
def test_float_to_fp8e4m3(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0])
|
||||
def test_float_to_fp8e4m3(self, x):
|
||||
np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.float8_e4m3fn).view(torch.uint8).item())
|
||||
|
||||
def test_float_to_fp8e4m3_extreme_values(self):
|
||||
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
|
||||
@@ -164,7 +163,8 @@ class TestFp8sConversions(unittest.TestCase):
|
||||
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX))
|
||||
def test_float_to_fp8e5m2(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0])
|
||||
def test_float_to_fp8e5m2(self, x):
|
||||
np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.float8_e5m2).view(torch.uint8).item())
|
||||
|
||||
def test_float_to_fp8e5m2_extreme_values(self):
|
||||
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
|
||||
@@ -177,10 +177,12 @@ class TestFp8sConversions(unittest.TestCase):
|
||||
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
|
||||
|
||||
@given(strat.integers(min_value=0, max_value=255))
|
||||
def test_fp8e4m3_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item())
|
||||
def test_fp8e4m3_to_float(self, x):
|
||||
np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e4m3fn).float().item())
|
||||
|
||||
@given(strat.integers(min_value=0, max_value=255))
|
||||
def test_fp8e5m2_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item())
|
||||
def test_fp8e5m2_to_float(self, x):
|
||||
np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2).float().item())
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
|
||||
class TestBFloat16(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user