update sampling in test_float_cast_to_unsigned (#13889)

filter is slow for small dtypes
This commit is contained in:
chenyu
2025-12-29 21:35:46 -05:00
committed by GitHub
parent 0497387e45
commit ab58926b00

View File

@@ -9,7 +9,7 @@ from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
import numpy as np
import pytest
from hypothesis import assume, given, strategies as strat, settings, HealthCheck
from hypothesis import assume, given, strategies as strat, settings
pytestmark = pytest.mark.filterwarnings("ignore")
@@ -206,29 +206,23 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
@settings(suppress_health_check=[HealthCheck.filter_too_much])
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
@given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
float_strat = float_strat.filter(lambda x: 0 < x < dtypes.max(unsigned_dtype))
universal_test_cast(a.draw(float_strat), float_dtype, unsigned_dtype)
universal_test_cast(a, float_dtype, unsigned_dtype)
@settings(suppress_health_check=[HealthCheck.filter_too_much])
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
@given(strat.floats(width=32, min_value=256.0, max_value=65000.0, allow_subnormal=False),
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
overflow_strat = float_strat.filter(lambda x: x > dtypes.max(unsigned_dtype) and x <= dtypes.max(dtypes.int32))
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)
universal_test_cast(a, float_dtype, unsigned_dtype)
@settings(suppress_health_check=[HealthCheck.filter_too_much])
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
@given(strat.floats(width=32, min_value=-65000.0, max_value=-1.0, allow_subnormal=False),
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
underflow_strat = float_strat.filter(lambda x: x < 0 and x >= dtypes.min(dtypes.int32))
universal_test_cast(a.draw(underflow_strat), float_dtype, unsigned_dtype)
universal_test_cast(a, float_dtype, unsigned_dtype)
@unittest.expectedFailure
def test_unsafe_cast_float_to_int_failure(self):