jit graph split tests (#11507)

* jit graph split tests

* fix

* one more test

* more tests

* fix

* xm

* remote
Author: nimlgen
Date: 2025-08-05 21:32:37 +03:00
Committed by: GitHub
Parent: c57fde51f9
Commit: fc4e713d1c

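For context before the diff: TinyJit replays a cached list of kernel invocations (f.jit_cache), and graph-capable backends batch compatible runs of that cache into GraphRunner entries. The tests added here pin down exactly where those batches must split: at device changes, CPU work, and copies. A minimal illustrative sketch of the mechanism, using only public tinygrad imports that also appear in the diff; it is not part of the change:

# illustrative sketch, not from the commit: after warm-up calls, f.jit_cache
# holds the replay list; graph-capable backends batch it into GraphRunner entries
from tinygrad import Tensor, TinyJit, Device
from tinygrad.engine.jit import GraphRunner

@TinyJit
def f(x):
  return (x + 1.0).contiguous().realize()

for _ in range(5):  # warm-up calls let the JIT capture; batching applies on replay
  f(Tensor.randn(4, 4, device=Device.DEFAULT).realize())
# True where a cache entry was batched into a graph, False for bare runners
print([isinstance(ei.prg, GraphRunner) for ei in f.jit_cache])
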
@@ -5,10 +5,10 @@ import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit, GraphRunner
+from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
 from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
 from tinygrad.device import Device
-from tinygrad.helpers import Context, JIT, GlobalCounters
 from tinygrad.runtime.support.hcq import HCQCompiled
+from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
 from tinygrad.dtype import dtypes
 from extra.models.unet import ResBlock
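A note on the widened imports: the new TestJitGraphSplit.expect helper (third hunk) keys its per-backend expectations off graph_class(dev), which reports the graph implementation a device would use, or None when the backend cannot graph at all. A condensed sketch of that dispatch, using the same calls as the diff; the branch comments summarize the split behavior the expectation lists encode:

from tinygrad.device import Device
from tinygrad.engine.jit import MultiGraphRunner, graph_class
from tinygrad.runtime.graph.hcq import HCQGraph  # the diff imports this lazily inside expect()

graph_t = graph_class(Device[Device.DEFAULT])
if graph_t is None:
  pass  # no graph support: nothing to validate
elif graph_t is HCQGraph:
  pass  # HCQ graphs fuse across devices, CPU kernels, and copies
elif issubclass(graph_t, MultiGraphRunner):
  pass  # spans same-backend devices and transfers, but splits around CPU work
else:
  pass  # single-device graphs split at every device change or copy
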
@@ -472,32 +472,6 @@ class TestJit(unittest.TestCase):
     np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
 
-  @unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
-  def test_jit_several_incompatible_devs(self):
-    assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
-    assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None
-    d0, d1 = Device.DEFAULT, "CPU"
-
-    @TinyJit
-    def f(a0, b0):
-      a1 = (a + 2.0).contiguous().realize()
-      a2 = (a1 * 2.0).contiguous().realize()
-      b1 = (b0 + 2.0).contiguous().realize()
-      b2 = (b1 * 2.0).contiguous().realize()
-      return a2, b2
-
-    for _ in range(5):
-      a = Tensor.randn(10, 10, device=d0).realize()
-      b = Tensor.randn(10, 10, device=d1).realize()
-      a1, b1 = f(a, b)
-      np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
-      np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)
-
-    assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
-
   @unittest.skipIf(not_support_multi_device(), "no multi")
   def test_jitted_view(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
@@ -696,5 +670,169 @@ class TestJitFree(unittest.TestCase):
     out = fxn(Tensor([11,1,2,3,4]))
     self.assertEqual(out.item(), 13600)
 
+class TestJitGraphSplit(unittest.TestCase):
+  def compute(self, device, inp):
+    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
+    return (inp + 1.0).contiguous().realize()
+
+  def copy(self, device, to_device, inp):
+    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
+    return inp.to(to_device).realize()
+
+  def expect(self, f, *args, graph=None, multigraph=None, hcqgraph=None):
+    def _numpies(tpl): return tpl.numpy() if tpl.__class__ is Tensor else tuple([t.numpy() for t in tpl])
+    expected = _numpies(f(*args))
+    for i in range(4):
+      res = _numpies(f(*args))
+      np.testing.assert_allclose(res, expected, atol=1e-4, rtol=1e-5)
+
+    dev = Device[Device.DEFAULT]
+    graph_t = graph_class(dev)
+    if graph_t is None: return
+
+    got = f.jit_cache
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    if graph_t is HCQGraph:
+      validate = hcqgraph
+    elif issubclass(graph_t, MultiGraphRunner):
+      validate = multigraph
+    else:
+      validate = graph
+
+    assert len(got) == len(validate), f"Expected {len(validate)} operations, got {len(got)}"
+    for expected, got in zip(validate, got):
+      if expected["type"] == "graph":
+        assert isinstance(got.prg, GraphRunner), f"Expected GraphRunner, got {type(got.prg)}"
+        assert len(got.prg.jit_cache) == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {len(got.prg.jit_cache)}"
+      elif expected["type"] == "comp":
+        assert isinstance(got.prg, CompiledRunner), f"Expected CompiledRunner, got {type(got.prg)}"
+      elif expected["type"] == "copy":
+        assert isinstance(got.prg, BufferCopy), f"Expected BufferCopy, got {type(got.prg)}"
+      elif expected["type"] == "xfer":
+        assert isinstance(got.prg, BufferXfer), f"Expected BufferXfer, got {type(got.prg)}"
+
+  def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
+  def ji_comp(self): return {"type": "comp"}
+  def ji_copy(self): return {"type": "copy"}
+  def ji_xfer(self): return {"type": "xfer"}
+
+  def test_jit_split_simple(self):
+    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
+
+    @TinyJit
+    def f(inp):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(Device.DEFAULT, op1)
+      return op2
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    self.expect(f, inp,
+                graph=[self.ji_graph(3)],
+                multigraph=[self.ji_graph(3)],
+                hcqgraph=[self.ji_graph(3)])
+
+  def test_jit_cpu_simple(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    @TinyJit
+    def f(inp, inp_cpu):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute("CPU", inp_cpu)
+      op3 = self.compute(Device.DEFAULT, op1)
+      return op2, op3
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
+    self.expect(f, inp, inp_cpu,
+                graph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
+                multigraph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
+                hcqgraph=[self.ji_graph(4)])
+
+  def test_jit_cpu_several(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    @TinyJit
+    def f(inp, inp_cpu):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute("CPU", inp_cpu)
+      op3 = self.compute("CPU", op2)
+      op4 = self.compute(Device.DEFAULT, op1)
+      return op3, op4
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
+    self.expect(f, inp, inp_cpu,
+                graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+                multigraph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+                hcqgraph=[self.ji_graph(5)])
+
+  def test_jit_multidev(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+    try: Device[f"{Device.DEFAULT}:1"]
+    except Exception: raise unittest.SkipTest("no multidevice")
+
+    @TinyJit
+    def f(inp, inp_d1):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
+      op3 = self.compute(f"{Device.DEFAULT}:1", op2)
+      op4 = self.compute(Device.DEFAULT, op1)
+      return op3, op4
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
+    self.expect(f, inp, inp_d1,
+                graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+                multigraph=[self.ji_graph(5)],
+                hcqgraph=[self.ji_graph(5)])
+
+  def test_jit_multidev_xfer(self):
+    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+    try: Device[f"{Device.DEFAULT}:1"]
+    except Exception: raise unittest.SkipTest("no multidevice")
+
+    @TinyJit
+    def f(inp, inp_d1):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
+      op3 = self.copy(f"{Device.DEFAULT}:1", Device.DEFAULT, op2)
+      op4 = self.compute(f"{Device.DEFAULT}:1", op2)
+      op5 = self.compute(Device.DEFAULT, op3)
+      return op1, op4, op5
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
+    self.expect(f, inp, inp_d1,
+                graph=[self.ji_graph(2), self.ji_comp(), self.ji_xfer(), self.ji_comp(), self.ji_comp()],
+                multigraph=[self.ji_graph(6)],
+                hcqgraph=[self.ji_graph(6)])
+
+  @unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies")
+  def test_jit_multidev_copy(self):
+    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
+
+    @TinyJit
+    def f(inp):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.copy(Device.DEFAULT, "CPU", op1)
+      op3 = self.compute("CPU", op2)
+      return op3
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    self.expect(f, inp,
+                graph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
+                multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
+                hcqgraph=[self.ji_graph(4)])
+
 if __name__ == '__main__':
   unittest.main()
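
To run only the new suite, something like the following should work from the repo root (the test/test_jit.py path is inferred from the test.helpers import in the first hunk, not stated in the diff):

python -m pytest test/test_jit.py -k TestJitGraphSplit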