cleanup test_dtype_alu (#2919)

wrapped long lines and lowered atol for METAL.sin to 2 since atol of two sins are bounded by 2
This commit is contained in:
chenyu
2023-12-22 17:29:31 -05:00
committed by GitHub
parent 3ba591c3fd
commit 089703a390

View File

@@ -61,7 +61,11 @@ def universal_test_unary(a, dtype, op):
ast = out.lazydata.schedule()[-1].ast
tensor_value = out.numpy()
numpy_value = op[1](np.array([a]).astype(dtype.np))
if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=5 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3, rtol=2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2) # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise) # noqa: E501
if dtype in dtypes_float:
atol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3
rtol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2
# exp and log and sin are approximations (in METAL, the default fast-math versions are less precise)
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
else: np.testing.assert_equal(tensor_value, numpy_value)
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
op = [x for x in ast.lazyops if x.op in UnaryOps][0]
@@ -137,7 +141,9 @@ class TestDTypeALU(unittest.TestCase):
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
# Metal and CUDACPU behave differently than numpy in CI for overflows
@given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations)) # noqa: E501
@given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32,
st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32,
ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations))
def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)
@given(ht.float32, st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))