Enable JIT tests for supported devices, skip METAL and WEBGPU (#1265)

* Enable JIT test

* really test metal

* Skip some device
This commit is contained in:
chenyu
2023-07-18 14:40:37 -04:00
committed by GitHub
parent f8c539989e
commit c96bf395df
3 changed files with 11 additions and 4 deletions

View File

@@ -239,9 +239,13 @@ jobs:
run: PYTHONPATH="." METAL=1 python3 test/external/external_test_speed_llama.py
#- name: Run dtype test
# run: DEBUG=4 METAL=1 python -m pytest test/test_dtype.py
# dtype test has issues on test_half_to_int8
- name: Run ops test
run: DEBUG=2 METAL=1 python -m pytest test/test_ops.py
# dtype test has issues on test_half_to_int8
- name: Run JIT test
run: DEBUG=2 METAL=1 python -m pytest test/test_jit.py
# TODO: why not testing the whole test/?
testdocker:
name: Docker Test

View File

@@ -2,9 +2,10 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.jit import TinyJit
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
@unittest.skipUnless(Device.DEFAULT == "GPU", "JIT is only for GPU")
# NOTE: METAL fails, might be platform and optimization options dependent.
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
def test_simple_jit(self):
@TinyJit

View File

@@ -6,6 +6,8 @@ from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters, RawBuffer
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
@@ -18,7 +20,7 @@ class TinyJit:
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_rawbuffers) != 0, "no inputs to JIT"