everything can jit now (#2338)

This commit is contained in:
chenyu
2023-11-16 23:54:57 -05:00
committed by GitHub
parent a8875bd770
commit 8e22c0d95c
2 changed files with 10 additions and 11 deletions

View File

@@ -244,18 +244,18 @@ jobs:
run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py
- name: Run whisper test
run: METAL=1 python -m pytest test/models/test_whisper.py
- name: Check Device.DEFAULT (WEBGPU) and print some source
run: |
WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
- name: Run linearizer and tensor core test
run: METAL=1 python -m pytest -n=auto test/test_linearizer.py
- name: Test tensor core reshape-only ops
run: METAL=1 TC=2 python -m pytest -n=auto test/test_ops.py
- name: Check Device.DEFAULT (WEBGPU) and print some source
run: |
WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
#- name: Run webgpu pytest
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto
- name: Run webgpu dtype tests
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py
- name: Run selected webgpu tests
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py test/test_jit.py
- name: Build WEBGPU Efficientnet
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
- name: Install Puppeteer

View File

@@ -1,13 +1,9 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
import pytest
pytestmark = pytest.mark.webgpu
@unittest.skipUnless(Device.DEFAULT != "WEBGPU", f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
def test_simple_jit(self):
@TinyJit
@@ -240,5 +236,8 @@ class TestJit(unittest.TestCase):
# but the bad_jitted doesn't!
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
assert len(cache.good_jitted.jit_cache) == 1
assert len(cache.bad_jitted.jit_cache) == 1
if __name__ == '__main__':
unittest.main()