From 40a4c603b9a3533ce936dc98b240fffddc257d50 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 12 Dec 2024 14:06:35 -0500
Subject: [PATCH] remove more test skip for webgpu [pr] (#8192)

---
 test/test_dtype.py            | 3 +--
 test/test_nn.py               | 1 -
 test/test_speed_v_torch.py    | 1 -
 test/test_tensor.py           | 1 -
 test/unit/test_disk_tensor.py | 2 +-
 5 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index a003bd2701..741c0dcfdb 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -646,8 +646,7 @@ class TestAutoCastType(unittest.TestCase):
   def test_broadcast_scalar(self, dt):
     assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
     assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
-    if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
-      assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
+    assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
 
   def test_sum(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
diff --git a/test/test_nn.py b/test/test_nn.py
index 14018c9a48..198b0eb1a6 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -14,7 +14,6 @@ from tinygrad.device import is_dtype_supported
 
 @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
 class TestNN(unittest.TestCase):
-  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no int64 on WebGPU")
   def test_sparse_cat_cross_entropy(self):
     # create in tinygrad
     input_tensor = Tensor.randn(6, 5) # not square to test that mean scaling uses the correct dimension
diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index d911ac004d..1cf1a3e044 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -137,7 +137,6 @@ class TestSpeed(unittest.TestCase):
     def f(a, b): return a-b
     helper_test_generic_square('sub', 4096, f, f)
 
-  @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "breaking on webgpu CI")
   def test_pow(self):
     def f(a, b): return a.pow(b)
     helper_test_generic_square('pow', 2048, f, f)
diff --git a/test/test_tensor.py b/test/test_tensor.py
index c1907d7130..d34556792e 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -135,7 +135,6 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5)
 
-  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs which breaks webgpu") #TODO: remove after #1461
   def test_backward_pass_diamond_model(self):
     def test_tinygrad():
       u = Tensor(U_init, requires_grad=True)
diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index ea8ac39b14..917a497e9d 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -82,7 +82,7 @@ class TestRawDiskBuffer(unittest.TestCase):
     pathlib.Path(tmp).unlink()
 
 
-@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
+@unittest.skipUnless(is_dtype_supported(dtypes.uint8), "need uint8")
 class TestSafetensors(unittest.TestCase):
   def test_real_safetensors(self):
     import torch
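
Note on the pattern in the last hunk: the skip now asks the backend whether it
supports the dtype the test needs, instead of hard-coding a device name. A minimal
sketch of that guard, assuming the imports the patched tests already use
(is_dtype_supported from tinygrad.device, Tensor and dtypes from tinygrad); the
test class and body are hypothetical, for illustration only:

    import unittest
    from tinygrad import Tensor, dtypes
    from tinygrad.device import is_dtype_supported

    # skipped automatically on backends (e.g. WebGPU) that lack uint8 support,
    # rather than checking Device.DEFAULT against a specific device name
    @unittest.skipUnless(is_dtype_supported(dtypes.uint8), "need uint8")
    class TestNeedsUint8(unittest.TestCase):
      def test_uint8_roundtrip(self):
        t = Tensor([0, 127, 255], dtype=dtypes.uint8)
        assert t.dtype == dtypes.uint8

    if __name__ == "__main__":
      unittest.main()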