diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 229ed788d4..41f40a14ab 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -239,9 +239,13 @@ jobs:
       run: PYTHONPATH="." METAL=1 python3 test/external/external_test_speed_llama.py
     #- name: Run dtype test
     #  run: DEBUG=4 METAL=1 python -m pytest test/test_dtype.py
+    # dtype test has issues on test_half_to_int8
     - name: Run ops test
       run: DEBUG=2 METAL=1 python -m pytest test/test_ops.py
-    # dtype test has issues on test_half_to_int8
+    - name: Run JIT test
+      run: DEBUG=2 METAL=1 python -m pytest test/test_jit.py
+    # TODO: why not testing the whole test/?
+

   testdocker:
     name: Docker Test
diff --git a/test/test_jit.py b/test/test_jit.py
index 3a9e1f04d4..1221c4860a 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2,9 +2,10 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
-from tinygrad.jit import TinyJit
+from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE

-@unittest.skipUnless(Device.DEFAULT == "GPU", "JIT is only for GPU")
+# NOTE: METAL fails, might be platform and optimization options dependent.
+@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
 class TestJit(unittest.TestCase):
   def test_simple_jit(self):
     @TinyJit
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index cc139efde5..e9da4a8f8a 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -6,6 +6,8 @@ from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, RawBuffer

+JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
+
 class TinyJit:
   def __init__(self, fxn:Callable):
     self.fxn: Callable = fxn
@@ -18,7 +20,7 @@ class TinyJit:
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

   def __call__(self, *args, **kwargs) -> Any:
-    if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
+    if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
     # NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
     input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
     assert len(input_rawbuffers) != 0, "no inputs to JIT"
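
For context, here is a minimal sketch of the usage pattern the new CI step exercises, mirroring `test_simple_jit` and the `JIT_SUPPORTED_DEVICE` gate introduced in this diff. The `add` function, tensor shapes, iteration count, and tolerances are illustrative, not part of the patch.

```python
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE

# On a device in JIT_SUPPORTED_DEVICE, repeated calls to the decorated function reuse
# the captured kernels; on any other device TinyJit just falls back to the plain function.
@TinyJit
def add(a: Tensor, b: Tensor) -> Tensor:
  # inputs must be Tensors (the JIT asserts "no inputs to JIT" otherwise) and the
  # result is realized inside the jitted function, as in test_simple_jit
  return (a + b).realize()

if Device.DEFAULT in JIT_SUPPORTED_DEVICE:  # same guard the test's skipUnless applies
  for _ in range(5):  # repeated calls with identically shaped inputs hit the jitted path
    a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
    np.testing.assert_allclose(add(a, b).numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5)
```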