mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
update sampling in test_float_cast_to_unsigned (#13889)
filter is slow for small dtypes
This commit is contained in:
@@ -9,7 +9,7 @@ from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
import numpy as np
|
||||
import pytest
|
||||
from hypothesis import assume, given, strategies as strat, settings, HealthCheck
|
||||
from hypothesis import assume, given, strategies as strat, settings
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings("ignore")
|
||||
|
||||
@@ -206,29 +206,23 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
|
||||
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
|
||||
|
||||
@settings(suppress_health_check=[HealthCheck.filter_too_much])
|
||||
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
@given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
|
||||
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
|
||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
|
||||
float_strat = float_strat.filter(lambda x: 0 < x < dtypes.max(unsigned_dtype))
|
||||
universal_test_cast(a.draw(float_strat), float_dtype, unsigned_dtype)
|
||||
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||
|
||||
@settings(suppress_health_check=[HealthCheck.filter_too_much])
|
||||
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
@given(strat.floats(width=32, min_value=256.0, max_value=65000.0, allow_subnormal=False),
|
||||
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
|
||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
|
||||
overflow_strat = float_strat.filter(lambda x: x > dtypes.max(unsigned_dtype) and x <= dtypes.max(dtypes.int32))
|
||||
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)
|
||||
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||
|
||||
@settings(suppress_health_check=[HealthCheck.filter_too_much])
|
||||
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
@given(strat.floats(width=32, min_value=-65000.0, max_value=-1.0, allow_subnormal=False),
|
||||
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
|
||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
|
||||
underflow_strat = float_strat.filter(lambda x: x < 0 and x >= dtypes.min(dtypes.int32))
|
||||
universal_test_cast(a.draw(underflow_strat), float_dtype, unsigned_dtype)
|
||||
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_unsafe_cast_float_to_int_failure(self):
|
||||
|
||||
Reference in New Issue
Block a user