mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
everything can jit now (#2338)
This commit is contained in:
12
.github/workflows/test.yml
vendored
12
.github/workflows/test.yml
vendored
@@ -244,18 +244,18 @@ jobs:
|
||||
run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py
|
||||
- name: Run whisper test
|
||||
run: METAL=1 python -m pytest test/models/test_whisper.py
|
||||
- name: Check Device.DEFAULT (WEBGPU) and print some source
|
||||
run: |
|
||||
WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
|
||||
WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
|
||||
- name: Run linearizer and tensor core test
|
||||
run: METAL=1 python -m pytest -n=auto test/test_linearizer.py
|
||||
- name: Test tensor core reshape-only ops
|
||||
run: METAL=1 TC=2 python -m pytest -n=auto test/test_ops.py
|
||||
- name: Check Device.DEFAULT (WEBGPU) and print some source
|
||||
run: |
|
||||
WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
|
||||
WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
|
||||
#- name: Run webgpu pytest
|
||||
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto
|
||||
- name: Run webgpu dtype tests
|
||||
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py
|
||||
- name: Run selected webgpu tests
|
||||
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py test/test_jit.py
|
||||
- name: Build WEBGPU Efficientnet
|
||||
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
|
||||
- name: Install Puppeteer
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.webgpu
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT != "WEBGPU", f"no JIT on {Device.DEFAULT}")
|
||||
class TestJit(unittest.TestCase):
|
||||
def test_simple_jit(self):
|
||||
@TinyJit
|
||||
@@ -240,5 +236,8 @@ class TestJit(unittest.TestCase):
|
||||
# but the bad_jitted doesn't!
|
||||
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
|
||||
|
||||
assert len(cache.good_jitted.jit_cache) == 1
|
||||
assert len(cache.bad_jitted.jit_cache) == 1
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user