update test_dtype_alu for METAL (#3629)

This commit is contained in:
chenyu
2024-03-06 14:55:19 -05:00
committed by GitHub
parent abc5f3a6a0
commit c270d54c32

View File

@@ -19,8 +19,8 @@ dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint
dtypes_bool = (dtypes.bool,)
binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]
# TODO: LLVM and METAL comparing with nan is incorrect
if Device.DEFAULT in {"LLVM", "METAL"}:
# TODO: LLVM comparing with nan is incorrect
if Device.DEFAULT == "LLVM":
binary_operations.remove(operator.lt)
binary_operations.remove(operator.eq)
@@ -38,8 +38,7 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (T
#binary_operations += [(Tensor.maximum, np.maximum)]
# TODO: CUDACPU segfaults on sin
# TODO: METAL sin can't handle infinity
if getenv("CUDACPU") or Device.DEFAULT == "METAL": unary_operations.remove((Tensor.sin, np.sin))
if getenv("CUDACPU"): unary_operations.remove((Tensor.sin, np.sin))
class ht:
float64 = strat.floats(width=64, allow_subnormal=False)
@@ -142,7 +141,7 @@ class TestDTypeALU(unittest.TestCase):
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
# Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
skip_overflow = CI and (Device.DEFAULT in ["METAL","HIP"] or getenv("CUDACPU"))
skip_overflow = CI and (Device.DEFAULT == "HIP" or getenv("CUDACPU"))
@given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))