mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Enable JIT tests for supported devices, skip METAL and WEBGPU (#1265)
* Enable JIT test * really test metal * Skip some device
This commit is contained in:
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -239,9 +239,13 @@ jobs:
|
||||
run: PYTHONPATH="." METAL=1 python3 test/external/external_test_speed_llama.py
|
||||
#- name: Run dtype test
|
||||
# run: DEBUG=4 METAL=1 python -m pytest test/test_dtype.py
|
||||
# dtype test has issues on test_half_to_int8
|
||||
- name: Run ops test
|
||||
run: DEBUG=2 METAL=1 python -m pytest test/test_ops.py
|
||||
# dtype test has issues on test_half_to_int8
|
||||
- name: Run JIT test
|
||||
run: DEBUG=2 METAL=1 python -m pytest test/test_jit.py
|
||||
# TODO: why not testing the whole test/?
|
||||
|
||||
|
||||
testdocker:
|
||||
name: Docker Test
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "JIT is only for GPU")
|
||||
# NOTE: METAL fails, might be platform and optimization options dependent.
|
||||
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
|
||||
class TestJit(unittest.TestCase):
|
||||
def test_simple_jit(self):
|
||||
@TinyJit
|
||||
|
||||
@@ -6,6 +6,8 @@ from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters, RawBuffer
|
||||
|
||||
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
|
||||
|
||||
class TinyJit:
|
||||
def __init__(self, fxn:Callable):
|
||||
self.fxn: Callable = fxn
|
||||
@@ -18,7 +20,7 @@ class TinyJit:
|
||||
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
|
||||
if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
|
||||
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
|
||||
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
|
||||
assert len(input_rawbuffers) != 0, "no inputs to JIT"
|
||||
|
||||
Reference in New Issue
Block a user