From 089703a390cb6b799fd364bf4ae7456e75939449 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 22 Dec 2023 17:29:31 -0500 Subject: [PATCH] cleanup test_dtype_alu (#2919) wrapped long lines and lowered atol for METAL.sin to 2 since atol of two sins are bounded by 2 --- test/test_dtype_alu.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index a08dfaa7a2..12f0f3f41c 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -61,7 +61,11 @@ def universal_test_unary(a, dtype, op): ast = out.lazydata.schedule()[-1].ast tensor_value = out.numpy() numpy_value = op[1](np.array([a]).astype(dtype.np)) - if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=5 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3, rtol=2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2) # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise) # noqa: E501 + if dtype in dtypes_float: + atol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3 + rtol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2 + # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise) + np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol) else: np.testing.assert_equal(tensor_value, numpy_value) if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends op = [x for x in ast.lazyops if x.op in UnaryOps][0] @@ -137,7 +141,9 @@ class TestDTypeALU(unittest.TestCase): def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32) # Metal and CUDACPU behave differently than numpy in CI for overflows - @given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations)) # noqa: E501 + @given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, + st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, + ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations)) def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32) @given(ht.float32, st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))