mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
remove ml_dtypes (#12169)
This commit is contained in:
@@ -6,7 +6,6 @@ from tinygrad.helpers import getenv, CI, DEBUG
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
import numpy as np
|
||||
import torch
|
||||
import ml_dtypes
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
@@ -190,7 +189,7 @@ class TestHelpers(unittest.TestCase):
|
||||
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), math.copysign(math.nan, x))
|
||||
elif x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
|
||||
elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), torch.tensor(x, dtype=torch.float8_e4m3fn).float().item())
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e5m2(self, x):
|
||||
@@ -198,7 +197,7 @@ class TestHelpers(unittest.TestCase):
|
||||
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
|
||||
elif x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
|
||||
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item())
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user