disallow subnormals in emulated test_dtype (#14744)

This commit is contained in:
Christopher Milan
2026-02-13 21:11:57 -08:00
committed by GitHub
parent 9d9ef81608
commit eaa9506a00
2 changed files with 12 additions and 8 deletions

View File

@@ -251,7 +251,7 @@ class TestEmulatedHalf(TestHalfDType):
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="half"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
@classmethod
def tearDownClass(cls): cls.stack.close()
@@ -355,7 +355,7 @@ class TestEmulatedInt64DType(TestInt64DType):
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
@classmethod
def tearDownClass(cls): cls.stack.close()
@@ -371,7 +371,7 @@ class TestEmulatedUInt64DType(TestUint64DType):
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
@classmethod
def tearDownClass(cls): cls.stack.close()
@@ -385,7 +385,7 @@ class TestEmulatedBFloat16Type(TestBFloat16Type):
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="bfloat16"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
@classmethod
def tearDownClass(cls): cls.stack.close()
@@ -397,7 +397,7 @@ class TestEmulatedFp8e4m3(TestFp8e4m3):
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e4m3"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
@classmethod
def tearDownClass(cls): cls.stack.close()
@@ -409,7 +409,7 @@ class TestEmulatedFp8e5m2(TestFp8e5m2):
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e5m2"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False)
@classmethod
def tearDownClass(cls): cls.stack.close()

View File

@@ -41,14 +41,18 @@ def assert_jit_cache_len(fxn, expected_len):
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
def rand_for_dtype(dt:DType, size:int):
def rand_for_dtype(dt:DType, size:int, allow_subnormal=True):
if dtypes.is_unsigned(dt):
return np.random.randint(0, 100, size=size, dtype=_to_np_dtype(dt))
elif dtypes.is_int(dt):
return np.random.randint(-100, 100, size=size, dtype=_to_np_dtype(dt))
elif dt == dtypes.bool:
return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
ret = np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
if not allow_subnormal:
min_normal = 2.0 ** (2 - (1 << (dtypes.finfo(dt)[0] - 1)))
ret = np.where(np.abs(ret) < min_normal, 0, ret)
return ret
def timeit(fxn:Callable[..., T], *args, **kwargs) -> tuple[T, float]:
st = time.perf_counter_ns()