Mirror of https://github.com/tinygrad/tinygrad.git
update test_dtype_alu for METAL (#3629)
@@ -19,8 +19,8 @@ dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint
 dtypes_bool = (dtypes.bool,)
 binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]
 
-# TODO: LLVM and METAL comparing with nan is incorrect
-if Device.DEFAULT in {"LLVM", "METAL"}:
+# TODO: LLVM comparing with nan is incorrect
+if Device.DEFAULT == "LLVM":
   binary_operations.remove(operator.lt)
   binary_operations.remove(operator.eq)
 
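The first hunk narrows the nan-comparison skip to LLVM only: on METAL, operator.lt and operator.eq involving nan are now expected to match numpy, where any comparison with nan is False. A standalone check of that behaviour, as a minimal sketch assuming tinygrad and numpy are installed and some working backend is available, could look like:

# Hedged sketch, not part of the commit: nan compared with anything should be
# False for < and ==, matching numpy; result dtype may vary by tinygrad version.
import numpy as np
from tinygrad import Tensor

a = np.array([float("nan"), 1.0, 2.0], dtype=np.float32)
b = np.array([2.0, float("nan"), 2.0], dtype=np.float32)

lt_out = (Tensor(a) < Tensor(b)).numpy()   # tinygrad element-wise less-than
eq_out = (Tensor(a) == Tensor(b)).numpy()  # tinygrad element-wise equality

np.testing.assert_equal(lt_out.astype(bool), a < b)   # numpy reference
np.testing.assert_equal(eq_out.astype(bool), a == b)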
@@ -38,8 +38,7 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (T
 #binary_operations += [(Tensor.maximum, np.maximum)]
 
 # TODO: CUDACPU segfaults on sin
-# TODO: METAL sin can't handle infinity
-if getenv("CUDACPU") or Device.DEFAULT == "METAL": unary_operations.remove((Tensor.sin, np.sin))
+if getenv("CUDACPU"): unary_operations.remove((Tensor.sin, np.sin))
 
 class ht:
   float64 = strat.floats(width=64, allow_subnormal=False)
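The second hunk drops METAL from the sin skip, leaving only the CUDACPU segfault workaround: Tensor.sin on infinite inputs is now expected to return nan on METAL, just like np.sin. A quick probe of that expectation, sketched under the same assumptions (tinygrad plus numpy installed, any working backend):

# Hedged sketch, not part of the commit: sin(+/-inf) should be nan, matching numpy.
import math
import numpy as np
from tinygrad import Tensor

out = Tensor([float("inf"), -float("inf"), 1.0]).sin().numpy()

assert math.isnan(out[0]) and math.isnan(out[1])   # sin(+/-inf) -> nan
np.testing.assert_allclose(out[2], np.sin(np.float32(1.0)), rtol=1e-6)  # finite input stays correct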
@@ -142,7 +141,7 @@ class TestDTypeALU(unittest.TestCase):
   def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
 
   # Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
-  skip_overflow = CI and (Device.DEFAULT in ["METAL","HIP"] or getenv("CUDACPU"))
+  skip_overflow = CI and (Device.DEFAULT == "HIP" or getenv("CUDACPU"))
   @given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
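The third hunk removes METAL from skip_overflow, so on Metal in CI the hypothesis strategies are no longer clamped to [0, 10] and overflowing float32 inputs are generated again; HIP and CUDACPU keep the clamp. The kind of disagreement the clamp guards against is roughly the following, shown as a hedged sketch (the exact result of casting an out-of-range float down to int32 is backend-defined, which is the point):

# Hedged sketch, not part of the commit: a float32 result outside int32 range
# has no single well-defined int32 value, so a backend and numpy may disagree.
import numpy as np
from tinygrad import Tensor, dtypes

a, b = 3.0e38, 3.0e38                                    # float32 addition overflows to inf
tiny = (Tensor([a]) + Tensor([b])).cast(dtypes.int32).numpy()
ref = (np.float32(a) + np.float32(b)).astype(np.int32)   # numpy's cast of inf
print(tiny, ref)                                         # values may differ per backend/driver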