diff --git a/test/backend/test_jit.py b/test/backend/test_jit.py
index fa92bca76f..d73f96508a 100644
--- a/test/backend/test_jit.py
+++ b/test/backend/test_jit.py
@@ -1,15 +1,15 @@
 #!/usr/bin/env python
-import unittest, functools
+import unittest
 import numpy as np
 from hypothesis import given, settings, strategies as strat
-from test.helpers import assert_jit_cache_len, not_support_multi_device, needs_second_gpu
+from test.helpers import assert_jit_cache_len, call_is_graph, not_support_multi_device, needs_second_gpu
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit, JitError, GraphRunner, MultiGraphRunner, graph_class
-from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
+from tinygrad.engine.jit import TinyJit, JitError, graph_class
 from tinygrad.device import Device
 from tinygrad.helpers import Context, JIT, DEV, GlobalCounters
 from tinygrad.dtype import dtypes
+from tinygrad.uop.ops import Ops
 from extra.models.unet import ResBlock
 
 def _simple_test(add, extract=lambda x: x, N=10):
@@ -419,10 +419,10 @@ class TestJit(unittest.TestCase):
       if prev is not None: np.testing.assert_allclose(o, prev, atol=1e-4, rtol=1e-5)
       prev = o
 
-    graph_t = Device[Device.DEFAULT].graph.func if isinstance(Device[Device.DEFAULT].graph, functools.partial) else Device[Device.DEFAULT].graph
     # Checking that 2 graphs are inited.
-    assert isinstance(jf.jit_cache[0].prg, graph_t)
-    assert isinstance(jf.jit_cache[1].prg, graph_t)
+    assert len(jf.captured.linear.src) == 2
+    for si in jf.captured.linear.src:
+      assert call_is_graph(si)
 
   def test_jitted_clone(self):
     def f(a): return a.clone().realize()
@@ -583,7 +583,7 @@ class TestJitPrune(unittest.TestCase):
     a = Tensor.rand(16).realize()
     out = w2_prune(a)
     np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
-    assert len(w2_prune.captured.jit_cache) == 1
+    assert_jit_cache_len(w2_prune, 1)
 
   def test_prune_w_copy_correct(self):
     weights = Tensor.rand(16).realize()
@@ -617,7 +617,7 @@ class TestJitPrune(unittest.TestCase):
 
     out = w2_prune(a)
     np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
-    assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
+    assert_jit_cache_len(w2_prune, 1)
 
 class TestJitFree(unittest.TestCase):
   def test_free_intermediates(self):
@@ -688,8 +688,9 @@ class TestJitGraphSplit(unittest.TestCase):
     graph_t = graph_class(dev)
     if graph_t is None: return
 
-    got = f.jit_cache
+    got = f.captured.linear.src
     from tinygrad.runtime.graph.hcq import HCQGraph
+    from tinygrad.engine.jit import MultiGraphRunner
     if graph_t is HCQGraph: validate = hcqgraph
     elif issubclass(graph_t, MultiGraphRunner):
@@ -698,16 +699,16 @@ class TestJitGraphSplit(unittest.TestCase):
       validate = graph
 
     assert len(got) == len(validate), f"Expected {len(validate)} operations, got {len(got)}"
-    for expected, got in zip(validate, got):
+    for expected, si in zip(validate, got):
+      ast = si.src[0]
       if expected["type"] == "graph":
-        assert isinstance(got.prg, GraphRunner), f"Expected GraphRunner, got {type(got.prg)}"
-        assert len(got.prg.jit_cache) == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {len(got.prg.jit_cache)}"
+        assert call_is_graph(si), f"Expected graph, got {ast.op}"
+        inner_cnt = len(ast.src[0].src)
+        assert inner_cnt == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {inner_cnt}"
      elif expected["type"] == "comp":
-        assert isinstance(got.prg, CompiledRunner), f"Expected CompiledRunner, got {type(got.prg)}"
-      elif expected["type"] == "copy":
-        assert isinstance(got.prg, BufferCopy), f"Expected BufferCopy, got {type(got.prg)}"
-      elif expected["type"] == "xfer":
-        assert isinstance(got.prg, BufferXfer), f"Expected BufferXfer, got {type(got.prg)}"
+        assert ast.op in (Ops.SINK, Ops.PROGRAM, Ops.BEAM), f"Expected kernel, got {ast.op}"
+      elif expected["type"] in ("copy", "xfer"):
+        assert ast.op is Ops.COPY, f"Expected COPY, got {ast.op}"
 
   def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
   def ji_comp(self): return {"type": "comp"}
diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py
index bdda8656d7..1e5f620191 100644
--- a/test/backend/test_multitensor.py
+++ b/test/backend/test_multitensor.py
@@ -1,13 +1,13 @@
-import unittest, functools, random
+import unittest, random
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
 from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.helpers import getenv, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
-from tinygrad.engine.realize import BufferCopy, CompiledRunner, run_schedule
+from tinygrad.engine.realize import CompiledRunner, run_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
-from test.helpers import not_support_multi_device, needs_second_gpu, slow
+from test.helpers import not_support_multi_device, needs_second_gpu, slow, call_is_graph
 
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
@@ -544,7 +544,7 @@ class TestMultiTensor(unittest.TestCase):
       b.shard_(devices_2)
       c = jf(a, b)
       np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
-    assert len(jf.jit_cache) > 0
+    assert jf.captured is not None
 
   def test_multi_tensor_jit_body(self):
     @TinyJit
@@ -558,7 +558,7 @@ class TestMultiTensor(unittest.TestCase):
     for _ in range(5):
       r = jf()
       np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
-    assert len(jf.jit_cache) > 0
+    assert jf.captured is not None
 
   def test_multitensor_jit_in_list(self):
     # test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
@@ -618,15 +618,12 @@ class TestMultiTensor(unittest.TestCase):
       o = jf(a, b, c, d).numpy()
       np.testing.assert_allclose(ref, o, atol=1e-4, rtol=1e-5)
 
-    graph_d0 = Device[d0].graph.func if isinstance(Device[d0].graph, functools.partial) else Device[d0].graph
-    graph_d1 = Device[d1].graph.func if isinstance(Device[d1].graph, functools.partial) else Device[d1].graph
     # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
-    assert isinstance(jf.jit_cache[0].prg, graph_d0)
-    assert isinstance(jf.jit_cache[1].prg, graph_d0)
-    assert isinstance(jf.jit_cache[2].prg, graph_d1)
-    assert isinstance(jf.jit_cache[3].prg, graph_d1)
-    assert isinstance(jf.jit_cache[4].prg, BufferCopy)
-    assert isinstance(jf.jit_cache[5].prg, graph_d1)
+    sis = jf.captured.linear.src
+    assert len(sis) == 6
+    for si in (sis[0], sis[1], sis[2], sis[3], sis[5]):
+      assert call_is_graph(si)
+    assert sis[4].src[0].op is Ops.COPY
 
   def test_bn_ast_on_devices(self):
     t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
diff --git a/test/helpers.py b/test/helpers.py
index 75b60f8e3e..ed4f55b19e 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -5,7 +5,7 @@ import numpy as np
 from tinygrad import Tensor, dtypes, Device
 from tinygrad.uop.ops import UOp, Ops, KernelInfo
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.engine.realize import Runner, get_program
+from tinygrad.engine.realize import get_program
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import T, CI, Target
@@ -31,18 +31,29 @@ def derandomize_model(model):
       p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
       p.realize()
 
+def call_is_graph(call:UOp) -> bool:
+  ast = call.src[0]
+  return ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph"
+
+def jit_cache_count(linear:UOp) -> int:
+  n = 0
+  for call in linear.src:
+    ast = call.src[0]
+    if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": n += jit_cache_count(ast.src[0])
+    else: n += 1
+  return n
+
 def assert_jit_cache_len(fxn, expected_len):
-  if not fxn.jit_cache:
+  linear = fxn.captured.linear if fxn.captured is not None else None
+  if linear is None or not linear.src:
     assert expected_len == 0, expected_len
     return
-  # until we have a better way of typing the prg in ExecItem
-  if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
-    assert len(fxn.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache)}"
+  if call_is_graph(linear.src[0]):
+    assert len(linear.src) == 1, len(linear.src)
+    inner = linear.src[0].src[0].src[0]  # LINEAR UOp inside CUSTOM_FUNCTION
+    assert len(inner.src) == expected_len, f"expected {expected_len}, got {len(inner.src)}"
   else:
-    assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
-    # until we have a better way of typing the prg in ExecItem
-    assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
-    assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache[0].prg.jit_cache)}"
+    assert len(linear.src) == expected_len, f"expected {expected_len}, got {len(linear.src)}"
 
 def rand_for_dtype(dt:DType, size:int, allow_subnormal=True):
   if dtypes.is_unsigned(dt):
diff --git a/test/null/test_real_world.py b/test/null/test_real_world.py
index 61b971d883..eface89ddb 100644
--- a/test/null/test_real_world.py
+++ b/test/null/test_real_world.py
@@ -6,7 +6,7 @@ from tinygrad.nn.state import get_parameters
 from tinygrad.engine.jit import TinyJit
 from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
 from tinygrad.helpers import Context
-from test.helpers import slow
+from test.helpers import slow, jit_cache_count
 from extra.lr_scheduler import OneCycleLR
 from test.helpers import derandomize_model
 
@@ -31,8 +31,7 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
     tms.append(time.perf_counter_ns() - st)
 
   mem_used = (GlobalCounters.mem_used - global_mem_used) / 1e9
-  # TODO: jit should expose this correctly with graph
-  kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
+  kernels_used = jit_cache_count(model.captured.linear) if getattr(model, "captured", None) is not None else None
   print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
   assert mem_used < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.3f} GB - {mem_used:.3} GB used"
   assert (max_memory_allowed - mem_used) / max_memory_allowed < 0.2, f"{max_memory_allowed:.3f} GB is too far from {mem_used:.3} GB used"