From ab58926b003e48b8b5a2b0734536d8d44525d6ec Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 29 Dec 2025 21:35:46 -0500 Subject: [PATCH] update sampling in test_float_cast_to_unsigned (#13889) filter is slow for small dtypes --- test/test_dtype_alu.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 375b087a80..102ae908b9 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -9,7 +9,7 @@ from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.nir import NIRRenderer import numpy as np import pytest -from hypothesis import assume, given, strategies as strat, settings, HealthCheck +from hypothesis import assume, given, strategies as strat, settings pytestmark = pytest.mark.filterwarnings("ignore") @@ -206,29 +206,23 @@ class TestDTypeALU(unittest.TestCase): @given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool)) def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype) - @settings(suppress_health_check=[HealthCheck.filter_too_much]) - @given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) + @given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False), + strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype): if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32 - float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype] - float_strat = float_strat.filter(lambda x: 0 < x < dtypes.max(unsigned_dtype)) - universal_test_cast(a.draw(float_strat), float_dtype, unsigned_dtype) + universal_test_cast(a, float_dtype, unsigned_dtype) - @settings(suppress_health_check=[HealthCheck.filter_too_much]) - @given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) + @given(strat.floats(width=32, min_value=256.0, max_value=65000.0, allow_subnormal=False), + strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype): if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32 - float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype] - overflow_strat = float_strat.filter(lambda x: x > dtypes.max(unsigned_dtype) and x <= dtypes.max(dtypes.int32)) - universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype) + universal_test_cast(a, float_dtype, unsigned_dtype) - @settings(suppress_health_check=[HealthCheck.filter_too_much]) - @given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) + @given(strat.floats(width=32, min_value=-65000.0, max_value=-1.0, allow_subnormal=False), + strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype): if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32 - float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype] - underflow_strat = float_strat.filter(lambda x: x < 0 and x >= dtypes.min(dtypes.int32)) - universal_test_cast(a.draw(underflow_strat), float_dtype, unsigned_dtype) + universal_test_cast(a, float_dtype, unsigned_dtype) @unittest.expectedFailure def test_unsafe_cast_float_to_int_failure(self):