update sampling in test_float_cast_to_unsigned (#13889)

filter is slow for small dtypes
This commit is contained in:
chenyu
2025-12-29 21:35:46 -05:00
committed by GitHub
parent 0497387e45
commit ab58926b00

View File

@@ -9,7 +9,7 @@ from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
import numpy as np import numpy as np
import pytest import pytest
from hypothesis import assume, given, strategies as strat, settings, HealthCheck from hypothesis import assume, given, strategies as strat, settings
pytestmark = pytest.mark.filterwarnings("ignore") pytestmark = pytest.mark.filterwarnings("ignore")
@@ -206,29 +206,23 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool)) @given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype) def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
@settings(suppress_health_check=[HealthCheck.filter_too_much]) @given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype): def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32 if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype] universal_test_cast(a, float_dtype, unsigned_dtype)
float_strat = float_strat.filter(lambda x: 0 < x < dtypes.max(unsigned_dtype))
universal_test_cast(a.draw(float_strat), float_dtype, unsigned_dtype)
@settings(suppress_health_check=[HealthCheck.filter_too_much]) @given(strat.floats(width=32, min_value=256.0, max_value=65000.0, allow_subnormal=False),
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype): def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32 if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype] universal_test_cast(a, float_dtype, unsigned_dtype)
overflow_strat = float_strat.filter(lambda x: x > dtypes.max(unsigned_dtype) and x <= dtypes.max(dtypes.int32))
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)
@settings(suppress_health_check=[HealthCheck.filter_too_much]) @given(strat.floats(width=32, min_value=-65000.0, max_value=-1.0, allow_subnormal=False),
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype): def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32 if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype] universal_test_cast(a, float_dtype, unsigned_dtype)
underflow_strat = float_strat.filter(lambda x: x < 0 and x >= dtypes.min(dtypes.int32))
universal_test_cast(a.draw(underflow_strat), float_dtype, unsigned_dtype)
@unittest.expectedFailure @unittest.expectedFailure
def test_unsafe_cast_float_to_int_failure(self): def test_unsafe_cast_float_to_int_failure(self):