Run WebGPU tests on ubuntu (#8033)

This commit is contained in:
Ahmed Harmouche
2024-12-04 12:42:04 +01:00
committed by GitHub
parent fb89971e73
commit 13eedd373b
3 changed files with 20 additions and 12 deletions

View File

@@ -60,6 +60,8 @@ class ht:
bool = strat.booleans()
def universal_test(a, b, dtype, op):
# The 'nan' cases only fail with Vulkan WebGPU backend (CI)
if (math.isnan(a) or math.isnan(b)) and Device.DEFAULT == "WEBGPU" and CI: return
if not isinstance(op, tuple): op = (op, op)
tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype)))
@@ -89,7 +91,7 @@ def universal_test_cast(a, in_dtype, dtype):
def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
# the 'inf' and 'nan' cases are wrong on WEBGPU
if (c in [math.inf, -math.inf] or math.isnan(c)) and Device.DEFAULT == "WEBGPU": return
if (any(map(math.isnan, [a, b, c])) or math.isinf(c)) and Device.DEFAULT == "WEBGPU": return
if not isinstance(op1, tuple): op1 = (op1, op1)
if not isinstance(op2, tuple): op2 = (op2, op2)
at, bt, ct = Tensor([a], dtype=d1), Tensor([b], dtype=d1), Tensor([c], dtype=d2)

View File

@@ -378,6 +378,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.round(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501]], forward_only=True)
helper_test_op(None, lambda x: x.round(), vals=[[2.5, -1.5]], forward_only=True)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "isinf check of 'nan' fails on CI software-based vulkan")
def test_isinf(self):
val = [float('-inf'), 0., float('inf'), float('nan'), 1.1]
helper_test_op(None, torch.isinf, Tensor.isinf, vals=[val], forward_only=True)