diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bf0e985c67..27977099d8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -244,18 +244,18 @@ jobs: run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py - name: Run whisper test run: METAL=1 python -m pytest test/models/test_whisper.py - - name: Check Device.DEFAULT (WEBGPU) and print some source - run: | - WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT" - WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add - name: Run linearizer and tensor core test run: METAL=1 python -m pytest -n=auto test/test_linearizer.py - name: Test tensor core reshape-only ops run: METAL=1 TC=2 python -m pytest -n=auto test/test_ops.py + - name: Check Device.DEFAULT (WEBGPU) and print some source + run: | + WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT" + WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add #- name: Run webgpu pytest # run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto - - name: Run webgpu dtype tests - run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py + - name: Run selected webgpu tests + run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py test/test_jit.py - name: Build WEBGPU Efficientnet run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet - name: Install Puppeteer diff --git a/test/test_jit.py b/test/test_jit.py index 6439124c40..5747bd79f8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1,13 +1,9 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.tensor import Tensor, Device +from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit -import pytest -pytestmark = pytest.mark.webgpu - -@unittest.skipUnless(Device.DEFAULT != "WEBGPU", f"no JIT on {Device.DEFAULT}") class TestJit(unittest.TestCase): def test_simple_jit(self): @TinyJit @@ -240,5 +236,8 @@ class TestJit(unittest.TestCase): # but the bad_jitted doesn't! np.testing.assert_equal([1], cache.bad_jitted(zero).numpy()) + assert len(cache.good_jitted.jit_cache) == 1 + assert len(cache.bad_jitted.jit_cache) == 1 + if __name__ == '__main__': unittest.main() \ No newline at end of file