mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
add needs_second_gpu decorator (#13543)
* add needs_second_gpu decorator * more skips * two more fixes
This commit is contained in:
@@ -3,7 +3,7 @@ import unittest, functools
|
||||
import numpy as np
|
||||
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
|
||||
from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV, needs_second_gpu
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
|
||||
@@ -439,6 +439,7 @@ class TestJit(unittest.TestCase):
|
||||
ja = jf(a)
|
||||
np.testing.assert_allclose(a.numpy(), ja.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
@needs_second_gpu
|
||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||
def test_jitted_transfers(self):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||
@@ -472,6 +473,7 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
@needs_second_gpu
|
||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||
def test_jitted_view(self):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||
|
||||
Reference in New Issue
Block a user