diff --git a/test/test_jit.py b/test/test_jit.py
index 1b30ffc3b7..489590838c 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5,9 +5,10 @@
 import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit
+from tinygrad.engine.jit import TinyJit, GraphRunner
 from tinygrad.device import Device
 from tinygrad.helpers import Context, JIT, GlobalCounters
+from tinygrad.runtime.support.hcq import HCQCompiled
 from tinygrad.dtype import dtypes
 from extra.models.unet import ResBlock
@@ -471,6 +472,32 @@ class TestJit(unittest.TestCase):
     np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
 
+  @unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
+  def test_jit_several_incompatible_devs(self):
+    assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
+    assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None
+
+    d0, d1 = Device.DEFAULT, "CPU"
+
+    @TinyJit
+    def f(a0, b0):
+      a1 = (a0 + 2.0).contiguous().realize()
+      a2 = (a1 * 2.0).contiguous().realize()
+
+      b1 = (b0 + 2.0).contiguous().realize()
+      b2 = (b1 * 2.0).contiguous().realize()
+
+      return a2, b2
+
+    for _ in range(5):
+      a = Tensor.randn(10, 10, device=d0).realize()
+      b = Tensor.randn(10, 10, device=d1).realize()
+      a1, b1 = f(a, b)
+      np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
+      np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)
+
+    assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
+
   @unittest.skipIf(not_support_multi_device(), "no multi")
   def test_jitted_view(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index e0e564957a..2254c606e4 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -52,14 +52,14 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
     can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
 
     # Check if the current batch can be extended with this item.
-    new_batched_devs = dedup(current_batch_devs + [ji_graph_dev])
-    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and graph_class(current_batch_devs[0]).supports_exec_item(new_batched_devs, ji)
+    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
+      graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
     can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
 
     # Flush the current batch if any, since it can't be extended or is full.
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
     (current_batch if can_be_graphed else graphed_jit_cache).append(ji)
-    current_batch_devs = new_batched_devs if can_be_graphed else []
+    current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
 
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache
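
Why the jit.py hunk matters: flush_batch() resets current_batch_devs, so a device list that was deduped before a potential flush can leak devices from the just-flushed batch into the next one. The standalone Python sketch below is not tinygrad code; batch_devices, its local dedup, and the "single device name" compatibility rule are illustrative stand-ins for graph_class(...).supports_exec_item. It only shows how recomputing the merged device list after the flush, as the patch does, changes the resulting batches.

def dedup(xs): return list(dict.fromkeys(xs))

def batch_devices(items, recompute_after_flush):
  # Mimics the batching loop in apply_graph_to_jit with a toy compatibility rule:
  # items may share a batch only if the merged device list has a single device name.
  batches, current_batch, current_batch_devs = [], [], []

  def flush_batch():
    nonlocal current_batch, current_batch_devs
    batches.append((current_batch, current_batch_devs))
    current_batch, current_batch_devs = [], []

  for dev in items:
    # Merged device list computed before any flush, like the removed new_batched_devs temporary.
    new_batched_devs = dedup(current_batch_devs + [dev])
    can_share_graph = len(current_batch_devs) > 0 and len(new_batched_devs) == 1
    if not can_share_graph and len(current_batch) > 0: flush_batch()
    current_batch.append(dev)
    # Old behavior reuses the precomputed list even though flush_batch() may have just reset
    # current_batch_devs; the fixed behavior recomputes it here, after the potential flush.
    current_batch_devs = dedup(current_batch_devs + [dev]) if recompute_after_flush else new_batched_devs
  if len(current_batch) > 0: flush_batch()
  return batches

print(batch_devices(["CPU", "CPU", "NV", "NV"], recompute_after_flush=False))
# -> [(['CPU', 'CPU'], ['CPU']), (['NV'], ['CPU', 'NV']), (['NV'], ['CPU', 'NV'])]
#    the stale 'CPU' entry leaks into the NV batches and keeps the two NV items from sharing one batch
print(batch_devices(["CPU", "CPU", "NV", "NV"], recompute_after_flush=True))
# -> [(['CPU', 'CPU'], ['CPU']), (['NV', 'NV'], ['NV'])]

The added test_jit_several_incompatible_devs exercises the real path: it JITs work on two devices that cannot share a graph (a non-HCQ default device plus CPU) and asserts that every cached ExecItem still ends up wrapped in a GraphRunner.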