mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
disallow subnormals in emulated test_dtype (#14744)
This commit is contained in:
committed by
GitHub
parent
9d9ef81608
commit
eaa9506a00
@@ -251,7 +251,7 @@ class TestEmulatedHalf(TestHalfDType):
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="half"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
@@ -355,7 +355,7 @@ class TestEmulatedInt64DType(TestInt64DType):
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
@@ -371,7 +371,7 @@ class TestEmulatedUInt64DType(TestUint64DType):
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
@@ -385,7 +385,7 @@ class TestEmulatedBFloat16Type(TestBFloat16Type):
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="bfloat16"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
@@ -397,7 +397,7 @@ class TestEmulatedFp8e4m3(TestFp8e4m3):
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e4m3"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
@@ -409,7 +409,7 @@ class TestEmulatedFp8e5m2(TestFp8e5m2):
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e5m2"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
@@ -41,14 +41,18 @@ def assert_jit_cache_len(fxn, expected_len):
|
||||
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
|
||||
|
||||
def rand_for_dtype(dt:DType, size:int):
|
||||
def rand_for_dtype(dt:DType, size:int, allow_subnormal=True):
|
||||
if dtypes.is_unsigned(dt):
|
||||
return np.random.randint(0, 100, size=size, dtype=_to_np_dtype(dt))
|
||||
elif dtypes.is_int(dt):
|
||||
return np.random.randint(-100, 100, size=size, dtype=_to_np_dtype(dt))
|
||||
elif dt == dtypes.bool:
|
||||
return np.random.choice([True, False], size=size)
|
||||
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
|
||||
ret = np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
|
||||
if not allow_subnormal:
|
||||
min_normal = 2.0 ** (2 - (1 << (dtypes.finfo(dt)[0] - 1)))
|
||||
ret = np.where(np.abs(ret) < min_normal, 0, ret)
|
||||
return ret
|
||||
|
||||
def timeit(fxn:Callable[..., T], *args, **kwargs) -> tuple[T, float]:
|
||||
st = time.perf_counter_ns()
|
||||
|
||||
Reference in New Issue
Block a user