mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
update sampling in test_float_cast_to_unsigned (#13889)
filter is slow for small dtypes
This commit is contained in:
@@ -9,7 +9,7 @@ from tinygrad.renderer.ptx import PTXRenderer
|
|||||||
from tinygrad.renderer.nir import NIRRenderer
|
from tinygrad.renderer.nir import NIRRenderer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from hypothesis import assume, given, strategies as strat, settings, HealthCheck
|
from hypothesis import assume, given, strategies as strat, settings
|
||||||
|
|
||||||
pytestmark = pytest.mark.filterwarnings("ignore")
|
pytestmark = pytest.mark.filterwarnings("ignore")
|
||||||
|
|
||||||
@@ -206,29 +206,23 @@ class TestDTypeALU(unittest.TestCase):
|
|||||||
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
|
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
|
||||||
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
|
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
|
||||||
|
|
||||||
@settings(suppress_health_check=[HealthCheck.filter_too_much])
|
@given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
|
||||||
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||||
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
|
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
|
||||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||||
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
|
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||||
float_strat = float_strat.filter(lambda x: 0 < x < dtypes.max(unsigned_dtype))
|
|
||||||
universal_test_cast(a.draw(float_strat), float_dtype, unsigned_dtype)
|
|
||||||
|
|
||||||
@settings(suppress_health_check=[HealthCheck.filter_too_much])
|
@given(strat.floats(width=32, min_value=256.0, max_value=65000.0, allow_subnormal=False),
|
||||||
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||||
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
|
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
|
||||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||||
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
|
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||||
overflow_strat = float_strat.filter(lambda x: x > dtypes.max(unsigned_dtype) and x <= dtypes.max(dtypes.int32))
|
|
||||||
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)
|
|
||||||
|
|
||||||
@settings(suppress_health_check=[HealthCheck.filter_too_much])
|
@given(strat.floats(width=32, min_value=-65000.0, max_value=-1.0, allow_subnormal=False),
|
||||||
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||||
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
|
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
|
||||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||||
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
|
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||||
underflow_strat = float_strat.filter(lambda x: x < 0 and x >= dtypes.min(dtypes.int32))
|
|
||||||
universal_test_cast(a.draw(underflow_strat), float_dtype, unsigned_dtype)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_unsafe_cast_float_to_int_failure(self):
|
def test_unsafe_cast_float_to_int_failure(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user