From fc4e713d1c075c725fe61540155df88dda0f47fa Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Tue, 5 Aug 2025 21:32:37 +0300
Subject: [PATCH] jit graph split tests (#11507)

* jit graph split tests

* fix

* one more test

* more tests

* fix

* xm

* rmeote
---
 test/test_jit.py | 196 ++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 167 insertions(+), 29 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 489590838c..449b63687f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5,10 +5,10 @@ import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit, GraphRunner
+from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
+from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
 from tinygrad.device import Device
-from tinygrad.helpers import Context, JIT, GlobalCounters
-from tinygrad.runtime.support.hcq import HCQCompiled
+from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
 from tinygrad.dtype import dtypes
 from extra.models.unet import ResBlock
 
@@ -472,32 +472,6 @@ class TestJit(unittest.TestCase):
       np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
       np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
 
-  @unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
-  def test_jit_several_incompatible_devs(self):
-    assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
-    assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None
-
-    d0, d1 = Device.DEFAULT, "CPU"
-
-    @TinyJit
-    def f(a0, b0):
-      a1 = (a + 2.0).contiguous().realize()
-      a2 = (a1 * 2.0).contiguous().realize()
-
-      b1 = (b0 + 2.0).contiguous().realize()
-      b2 = (b1 * 2.0).contiguous().realize()
-
-      return a2, b2
-
-    for _ in range(5):
-      a = Tensor.randn(10, 10, device=d0).realize()
-      b = Tensor.randn(10, 10, device=d1).realize()
-      a1, b1 = f(a, b)
-      np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
-      np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)
-
-    assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
-
   @unittest.skipIf(not_support_multi_device(), "no multi")
   def test_jitted_view(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
@@ -696,5 +670,169 @@ class TestJitFree(unittest.TestCase):
     out = fxn(Tensor([11,1,2,3,4]))
     self.assertEqual(out.item(), 13600)
 
+class TestJitGraphSplit(unittest.TestCase):
+  def compute(self, device, inp):
+    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
+    return (inp + 1.0).contiguous().realize()
+
+  def copy(self, device, to_device, inp):
+    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
+    return inp.to(to_device).realize()
+
+  def expect(self, f, *args, graph=None, multigraph=None, hcqgraph=None):
+    def _numpies(tpl): return tpl.numpy() if tpl.__class__ is Tensor else tuple([t.numpy() for t in tpl])
+
+    expected = _numpies(f(*args))
+    for i in range(4):
+      res = _numpies(f(*args))
+      np.testing.assert_allclose(res, expected, atol=1e-4, rtol=1e-5)
+
+    dev = Device[Device.DEFAULT]
+    graph_t = graph_class(dev)
+    if graph_t is None: return
+
+    got = f.jit_cache
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    if graph_t is HCQGraph:
+      validate = hcqgraph
+    elif issubclass(graph_t, MultiGraphRunner):
+      validate = multigraph
+    else:
+      validate = graph
+
+    assert len(got) == len(validate), f"Expected {len(validate)} operations, got {len(got)}"
+    for expected, got in zip(validate, got):
+      if expected["type"] == "graph":
+        assert isinstance(got.prg, GraphRunner), f"Expected GraphRunner, got {type(got.prg)}"
+        assert len(got.prg.jit_cache) == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {len(got.prg.jit_cache)}"
+      elif expected["type"] == "comp":
+        assert isinstance(got.prg, CompiledRunner), f"Expected CompiledRunner, got {type(got.prg)}"
+      elif expected["type"] == "copy":
+        assert isinstance(got.prg, BufferCopy), f"Expected BufferCopy, got {type(got.prg)}"
+      elif expected["type"] == "xfer":
+        assert isinstance(got.prg, BufferXfer), f"Expected BufferXfer, got {type(got.prg)}"
+
+  def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
+  def ji_comp(self): return {"type": "comp"}
+  def ji_copy(self): return {"type": "copy"}
+  def ji_xfer(self): return {"type": "xfer"}
+
+  def test_jit_split_simple(self):
+    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
+
+    @TinyJit
+    def f(inp):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(Device.DEFAULT, op1)
+      return op2
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    self.expect(f, inp,
+      graph=[self.ji_graph(3)],
+      multigraph=[self.ji_graph(3)],
+      hcqgraph=[self.ji_graph(3)])
+
+  def test_jit_cpu_simple(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    @TinyJit
+    def f(inp, inp_cpu):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute("CPU", inp_cpu)
+      op3 = self.compute(Device.DEFAULT, op1)
+      return op2, op3
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
+    self.expect(f, inp, inp_cpu,
+      graph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
+      multigraph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
+      hcqgraph=[self.ji_graph(4)])
+
+  def test_jit_cpu_several(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    @TinyJit
+    def f(inp, inp_cpu):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute("CPU", inp_cpu)
+      op3 = self.compute("CPU", op2)
+      op4 = self.compute(Device.DEFAULT, op1)
+      return op3, op4
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
+    self.expect(f, inp, inp_cpu,
+      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+      multigraph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+      hcqgraph=[self.ji_graph(5)])
+
+  def test_jit_multidev(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    try: Device[f"{Device.DEFAULT}:1"]
+    except Exception: raise unittest.SkipTest("no multidevice")
+
+    @TinyJit
+    def f(inp, inp_d1):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
+      op3 = self.compute(f"{Device.DEFAULT}:1", op2)
+      op4 = self.compute(Device.DEFAULT, op1)
+      return op3, op4
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
+    self.expect(f, inp, inp_d1,
+      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+      multigraph=[self.ji_graph(5)],
+      hcqgraph=[self.ji_graph(5)])
+
+  def test_jit_multidev_xfer(self):
+    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+
+    try: Device[f"{Device.DEFAULT}:1"]
+    except Exception: raise unittest.SkipTest("no multidevice")
+
+    @TinyJit
+    def f(inp, inp_d1):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
+      op3 = self.copy(f"{Device.DEFAULT}:1", Device.DEFAULT, op2)
+      op4 = self.compute(f"{Device.DEFAULT}:1", op2)
+      op5 = self.compute(Device.DEFAULT, op3)
+      return op1, op4, op5
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
+    self.expect(f, inp, inp_d1,
+      graph=[self.ji_graph(2), self.ji_comp(), self.ji_xfer(), self.ji_comp(), self.ji_comp()],
+      multigraph=[self.ji_graph(6)],
+      hcqgraph=[self.ji_graph(6)])
+
+  @unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies")
+  def test_jit_multidev_copy(self):
+    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
+
+    @TinyJit
+    def f(inp):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.copy(Device.DEFAULT, "CPU", op1)
+      op3 = self.compute("CPU", op2)
+      return op3
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    self.expect(f, inp,
+      graph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
+      multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
+      hcqgraph=[self.ji_graph(4)])
+
 if __name__ == '__main__':
   unittest.main()