diff --git a/test/backend/test_dtype.py b/test/backend/test_dtype.py index e3c9dbac8d..2c60675ac5 100644 --- a/test/backend/test_dtype.py +++ b/test/backend/test_dtype.py @@ -251,7 +251,7 @@ class TestEmulatedHalf(TestHalfDType): def setUpClass(cls): cls.stack = contextlib.ExitStack() cls.stack.enter_context(Context(EMULATED_DTYPES="half")) - cls.DATA = rand_for_dtype(cls.DTYPE, 10) + cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False) @classmethod def tearDownClass(cls): cls.stack.close() @@ -355,7 +355,7 @@ class TestEmulatedInt64DType(TestInt64DType): def setUpClass(cls): cls.stack = contextlib.ExitStack() cls.stack.enter_context(Context(EMULATED_DTYPES="long")) - cls.DATA = rand_for_dtype(cls.DTYPE, 10) + cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False) @classmethod def tearDownClass(cls): cls.stack.close() @@ -371,7 +371,7 @@ class TestEmulatedUInt64DType(TestUint64DType): def setUpClass(cls): cls.stack = contextlib.ExitStack() cls.stack.enter_context(Context(EMULATED_DTYPES="long")) - cls.DATA = rand_for_dtype(cls.DTYPE, 10) + cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False) @classmethod def tearDownClass(cls): cls.stack.close() @@ -385,7 +385,7 @@ class TestEmulatedBFloat16Type(TestBFloat16Type): def setUpClass(cls): cls.stack = contextlib.ExitStack() cls.stack.enter_context(Context(EMULATED_DTYPES="bfloat16")) - cls.DATA = rand_for_dtype(cls.DTYPE, 10) + cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False) @classmethod def tearDownClass(cls): cls.stack.close() @@ -397,7 +397,7 @@ class TestEmulatedFp8e4m3(TestFp8e4m3): def setUpClass(cls): cls.stack = contextlib.ExitStack() cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e4m3")) - cls.DATA = rand_for_dtype(cls.DTYPE, 10) + cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False) @classmethod def tearDownClass(cls): cls.stack.close() @@ -409,7 +409,7 @@ class TestEmulatedFp8e5m2(TestFp8e5m2): def setUpClass(cls): cls.stack = contextlib.ExitStack() cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e5m2")) - cls.DATA = rand_for_dtype(cls.DTYPE, 10) + cls.DATA = rand_for_dtype(cls.DTYPE, 10, allow_subnormal=False) @classmethod def tearDownClass(cls): cls.stack.close() diff --git a/test/helpers.py b/test/helpers.py index 8e198eab54..9ec5bbe008 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -41,14 +41,18 @@ def assert_jit_cache_len(fxn, expected_len): assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph') assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len -def rand_for_dtype(dt:DType, size:int): +def rand_for_dtype(dt:DType, size:int, allow_subnormal=True): if dtypes.is_unsigned(dt): return np.random.randint(0, 100, size=size, dtype=_to_np_dtype(dt)) elif dtypes.is_int(dt): return np.random.randint(-100, 100, size=size, dtype=_to_np_dtype(dt)) elif dt == dtypes.bool: return np.random.choice([True, False], size=size) - return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) + ret = np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) + if not allow_subnormal: + min_normal = 2.0 ** (2 - (1 << (dtypes.finfo(dt)[0] - 1))) + ret = np.where(np.abs(ret) < min_normal, 0, ret) + return ret def timeit(fxn:Callable[..., T], *args, **kwargs) -> tuple[T, float]: st = time.perf_counter_ns()