mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
cleanup test_dtype_alu (#2919)
wrapped long lines and lowered atol for METAL.sin to 2 since atol of two sins are bounded by 2
This commit is contained in:
@@ -61,7 +61,11 @@ def universal_test_unary(a, dtype, op):
|
||||
ast = out.lazydata.schedule()[-1].ast
|
||||
tensor_value = out.numpy()
|
||||
numpy_value = op[1](np.array([a]).astype(dtype.np))
|
||||
if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=5 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3, rtol=2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2) # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise) # noqa: E501
|
||||
if dtype in dtypes_float:
|
||||
atol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3
|
||||
rtol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2
|
||||
# exp and log and sin are approximations (in METAL, the default fast-math versions are less precise)
|
||||
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
|
||||
else: np.testing.assert_equal(tensor_value, numpy_value)
|
||||
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
|
||||
op = [x for x in ast.lazyops if x.op in UnaryOps][0]
|
||||
@@ -137,7 +141,9 @@ class TestDTypeALU(unittest.TestCase):
|
||||
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
|
||||
|
||||
# Metal and CUDACPU behave differently than numpy in CI for overflows
|
||||
@given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations)) # noqa: E501
|
||||
@given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32,
|
||||
st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32,
|
||||
ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations))
|
||||
def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)
|
||||
|
||||
@given(ht.float32, st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
|
||||
|
||||
Reference in New Issue
Block a user