Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
jit graph split tests (#11507)
* jit graph split tests
* fix
* one more test
* more tests
* fix
* xm
* remote
test/test_jit.py: 196 changes
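For orientation: TinyJit records every kernel launched during capture into f.jit_cache, and graph-capable backends batch consecutive compatible entries into GraphRunner objects; the tests added here pin down exactly where those batches must split. A minimal sketch of inspecting the cache (the exact batching depends on the backend, so the printed shape will vary):

```python
from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import GraphRunner

@TinyJit
def f(x):
  return ((x + 1.0).contiguous() * 2.0).contiguous().realize()

for _ in range(3):  # TinyJit needs warmup calls before it captures and replays
  out = f(Tensor.randn(4, 4).realize())

# on a graph-capable backend, captured kernels end up wrapped in GraphRunner entries
print([type(ei.prg).__name__ for ei in f.jit_cache])
```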
@@ -5,10 +5,10 @@ import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit, GraphRunner
+from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
 from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
 from tinygrad.device import Device
-from tinygrad.helpers import Context, JIT, GlobalCounters
 from tinygrad.runtime.support.hcq import HCQCompiled
+from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
 from tinygrad.dtype import dtypes
 from extra.models.unet import ResBlock
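The new graph_class import is what the tests below branch on: it apparently resolves the graph-runner class a device would use (None when the backend cannot graph at all), and MultiGraphRunner marks the classes that can batch kernels from several devices into one graph. A small probe mirroring that usage:

```python
from tinygrad.device import Device
from tinygrad.engine.jit import MultiGraphRunner, graph_class

dev = Device[Device.DEFAULT]
graph_t = graph_class(dev)
if graph_t is None:
  print("no graph support on", Device.DEFAULT)
else:
  print(graph_t.__name__, "multi-device capable:", issubclass(graph_t, MultiGraphRunner))
```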
@@ -472,32 +472,6 @@ class TestJit(unittest.TestCase):
     np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
 
-  @unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
-  def test_jit_several_incompatible_devs(self):
-    assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
-    assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None
-
-    d0, d1 = Device.DEFAULT, "CPU"
-
-    @TinyJit
-    def f(a0, b0):
-      a1 = (a0 + 2.0).contiguous().realize()
-      a2 = (a1 * 2.0).contiguous().realize()
-
-      b1 = (b0 + 2.0).contiguous().realize()
-      b2 = (b1 * 2.0).contiguous().realize()
-
-      return a2, b2
-
-    for _ in range(5):
-      a = Tensor.randn(10, 10, device=d0).realize()
-      b = Tensor.randn(10, 10, device=d1).realize()
-      a1, b1 = f(a, b)
-      np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
-      np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)
-
-    assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
-
   @unittest.skipIf(not_support_multi_device(), "no multi")
   def test_jitted_view(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
@@ -696,5 +670,169 @@ class TestJitFree(unittest.TestCase):
     out = fxn(Tensor([11,1,2,3,4]))
     self.assertEqual(out.item(), 13600)
 
+class TestJitGraphSplit(unittest.TestCase):
+  def compute(self, device, inp):
+    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
+    return (inp + 1.0).contiguous().realize()
+
+  def copy(self, device, to_device, inp):
+    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
+    return inp.to(to_device).realize()
+
+  def expect(self, f, *args, graph=None, multigraph=None, hcqgraph=None):
+    def _numpies(tpl): return tpl.numpy() if tpl.__class__ is Tensor else tuple([t.numpy() for t in tpl])
+
+    expected = _numpies(f(*args))
+    for i in range(4):
+      res = _numpies(f(*args))
+      np.testing.assert_allclose(res, expected, atol=1e-4, rtol=1e-5)
+
+    dev = Device[Device.DEFAULT]
+    graph_t = graph_class(dev)
+    if graph_t is None: return
+
+    got = f.jit_cache
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    if graph_t is HCQGraph:
+      validate = hcqgraph
+    elif issubclass(graph_t, MultiGraphRunner):
+      validate = multigraph
+    else:
+      validate = graph
+
+    assert len(got) == len(validate), f"Expected {len(validate)} operations, got {len(got)}"
+    for expected, got in zip(validate, got):
+      if expected["type"] == "graph":
+        assert isinstance(got.prg, GraphRunner), f"Expected GraphRunner, got {type(got.prg)}"
+        assert len(got.prg.jit_cache) == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {len(got.prg.jit_cache)}"
+      elif expected["type"] == "comp":
+        assert isinstance(got.prg, CompiledRunner), f"Expected CompiledRunner, got {type(got.prg)}"
+      elif expected["type"] == "copy":
+        assert isinstance(got.prg, BufferCopy), f"Expected BufferCopy, got {type(got.prg)}"
+      elif expected["type"] == "xfer":
+        assert isinstance(got.prg, BufferXfer), f"Expected BufferXfer, got {type(got.prg)}"
+
+  def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
+  def ji_comp(self): return {"type": "comp"}
+  def ji_copy(self): return {"type": "copy"}
+  def ji_xfer(self): return {"type": "xfer"}
+
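The expect() helper above encodes, per backend family, the partition f.jit_cache must have after capture: each ji_graph(n) entry is a GraphRunner wrapping n inner jit items, while ji_comp, ji_copy, and ji_xfer are bare CompiledRunner, BufferCopy, and BufferXfer entries. A hand-rolled sketch of the same walk, usable on any warmed-up jitted function and touching only the attributes the helper itself uses:

```python
from tinygrad.engine.jit import GraphRunner

def describe(f):
  # summarize a captured TinyJit cache the way expect() validates it
  out = []
  for ei in f.jit_cache:
    if isinstance(ei.prg, GraphRunner): out.append(f"graph({len(ei.prg.jit_cache)})")
    else: out.append(type(ei.prg).__name__)
  return out  # e.g. ["graph(2)", "BufferCopy", "CompiledRunner"]
```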
+  def test_jit_split_simple(self):
+    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
+
+    @TinyJit
+    def f(inp):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(Device.DEFAULT, op1)
+      return op2
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    self.expect(f, inp,
+                graph=[self.ji_graph(3)],
+                multigraph=[self.ji_graph(3)],
+                hcqgraph=[self.ji_graph(3)])
+
+  def test_jit_cpu_simple(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    @TinyJit
+    def f(inp, inp_cpu):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute("CPU", inp_cpu)
+      op3 = self.compute(Device.DEFAULT, op1)
+      return op2, op3
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
+    self.expect(f, inp, inp_cpu,
+                graph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
+                multigraph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
+                hcqgraph=[self.ji_graph(4)])
+
+  def test_jit_cpu_several(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    @TinyJit
+    def f(inp, inp_cpu):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute("CPU", inp_cpu)
+      op3 = self.compute("CPU", op2)
+      op4 = self.compute(Device.DEFAULT, op1)
+      return op3, op4
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
+    self.expect(f, inp, inp_cpu,
+                graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+                multigraph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+                hcqgraph=[self.ji_graph(5)])
+
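The expectations in the two CPU tests above suggest a simple split rule for backends whose graphs are single-device: consecutive jit items on one device form a run, a run of two or more kernels is wrapped in a GraphRunner, and a stranded single kernel stays a bare CompiledRunner. A toy sketch of that reading (my own illustration, not tinygrad's actual partitioning code):

```python
def toy_partition(devices):
  # group consecutive ops by device, then graph any run of 2+ kernels
  runs, cur = [], [devices[0]]
  for d in devices[1:]:
    if d == cur[-1]: cur.append(d)
    else: runs, cur = runs + [cur], [d]
  runs.append(cur)
  return [f"graph({len(r)})" if len(r) > 1 else "comp" for r in runs]

# matches the graph= expectation of test_jit_cpu_several above
assert toy_partition(["GPU", "GPU", "CPU", "CPU", "GPU"]) == ["graph(2)", "graph(2)", "comp"]
```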
+  def test_jit_multidev(self):
+    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
+
+    try: Device[f"{Device.DEFAULT}:1"]
+    except Exception: raise unittest.SkipTest("no multidevice")
+
+    @TinyJit
+    def f(inp, inp_d1):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
+      op3 = self.compute(f"{Device.DEFAULT}:1", op2)
+      op4 = self.compute(Device.DEFAULT, op1)
+      return op3, op4
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
+    self.expect(f, inp, inp_d1,
+                graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
+                multigraph=[self.ji_graph(5)],
+                hcqgraph=[self.ji_graph(5)])
+
+  def test_jit_multidev_xfer(self):
+    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+
+    try: Device[f"{Device.DEFAULT}:1"]
+    except Exception: raise unittest.SkipTest("no multidevice")
+
+    @TinyJit
+    def f(inp, inp_d1):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
+      op3 = self.copy(f"{Device.DEFAULT}:1", Device.DEFAULT, op2)
+      op4 = self.compute(f"{Device.DEFAULT}:1", op2)
+      op5 = self.compute(Device.DEFAULT, op3)
+      return op1, op4, op5
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
+    self.expect(f, inp, inp_d1,
+                graph=[self.ji_graph(2), self.ji_comp(), self.ji_xfer(), self.ji_comp(), self.ji_comp()],
+                multigraph=[self.ji_graph(6)],
+                hcqgraph=[self.ji_graph(6)])
+
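The pairing of this test with test_jit_multidev_copy below illustrates the split between the two copy runners: moving a buffer between two devices of the same backend shows up as a BufferXfer, while crossing backends (here, to CPU) shows up as a BufferCopy. A hedged sketch of observing the latter, assuming a non-CPU default device:

```python
from tinygrad import Tensor, TinyJit
from tinygrad.device import Device

@TinyJit
def j(x):
  # a cross-backend move; expected to appear as a BufferCopy jit item
  return x.to("CPU").realize()

for _ in range(3):
  j(Tensor.randn(4, 4, device=Device.DEFAULT).realize())
print([type(ei.prg).__name__ for ei in j.jit_cache])
```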
+  @unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies")
+  def test_jit_multidev_copy(self):
+    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
+
+    @TinyJit
+    def f(inp):
+      op0 = self.compute(Device.DEFAULT, inp)
+      op1 = self.compute(Device.DEFAULT, op0)
+      op2 = self.copy(Device.DEFAULT, "CPU", op1)
+      op3 = self.compute("CPU", op2)
+      return op3
+
+    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
+    self.expect(f, inp,
+                graph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
+                multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
+                hcqgraph=[self.ji_graph(4)])
+
 if __name__ == '__main__':
   unittest.main()