do not use jit_cache in test (#15823)

* do not use jit_cache in test

* fix
Author: nimlgen
Date: 2026-04-20 11:45:17 +03:00 (committed by GitHub)
parent 5819c0abed, commit c0d7135b5f
4 changed files with 51 additions and 43 deletions
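
Context for the diff below: the tests stop reaching into the jit_cache list and instead inspect fxn.captured.linear, a UOp whose src holds one call per scheduled item, with graphed runs wrapped in an Ops.CUSTOM_FUNCTION node tagged "graph". A minimal sketch of the new inspection style (the add function here is illustrative, not part of the commit):

from tinygrad import Tensor, TinyJit
from tinygrad.uop.ops import Ops

@TinyJit
def add(a: Tensor, b: Tensor) -> Tensor: return (a + b).realize()

# run enough times for the JIT to capture
for _ in range(5): add(Tensor.rand(16), Tensor.rand(16))

# new-style inspection: walk the captured linear schedule instead of jit_cache
assert add.captured is not None
for call in add.captured.linear.src:
  ast = call.src[0]
  print("graph" if (ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph") else ast.op)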

test/test_jit.py

@@ -1,15 +1,15 @@
 #!/usr/bin/env python
-import unittest, functools
+import unittest
 import numpy as np
 from hypothesis import given, settings, strategies as strat
-from test.helpers import assert_jit_cache_len, not_support_multi_device, needs_second_gpu
+from test.helpers import assert_jit_cache_len, call_is_graph, not_support_multi_device, needs_second_gpu
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit, JitError, GraphRunner, MultiGraphRunner, graph_class
-from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
+from tinygrad.engine.jit import TinyJit, JitError, graph_class
 from tinygrad.device import Device
 from tinygrad.helpers import Context, JIT, DEV, GlobalCounters
 from tinygrad.dtype import dtypes
+from tinygrad.uop.ops import Ops
 from extra.models.unet import ResBlock

 def _simple_test(add, extract=lambda x: x, N=10):
@@ -419,10 +419,10 @@ class TestJit(unittest.TestCase):
       if prev is not None: np.testing.assert_allclose(o, prev, atol=1e-4, rtol=1e-5)
       prev = o
-    graph_t = Device[Device.DEFAULT].graph.func if isinstance(Device[Device.DEFAULT].graph, functools.partial) else Device[Device.DEFAULT].graph
     # Checking that 2 graphs are inited.
-    assert isinstance(jf.jit_cache[0].prg, graph_t)
-    assert isinstance(jf.jit_cache[1].prg, graph_t)
+    assert len(jf.captured.linear.src) == 2
+    for si in jf.captured.linear.src:
+      assert call_is_graph(si)

   def test_jitted_clone(self):
     def f(a): return a.clone().realize()
@@ -583,7 +583,7 @@ class TestJitPrune(unittest.TestCase):
     a = Tensor.rand(16).realize()
     out = w2_prune(a)
     np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
-    assert len(w2_prune.captured.jit_cache) == 1
+    assert_jit_cache_len(w2_prune, 1)

   def test_prune_w_copy_correct(self):
     weights = Tensor.rand(16).realize()
@@ -617,7 +617,7 @@ class TestJitPrune(unittest.TestCase):
     out = w2_prune(a)
     np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
-    assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
+    assert_jit_cache_len(w2_prune, 1)

 class TestJitFree(unittest.TestCase):
   def test_free_intermediates(self):
@@ -688,8 +688,9 @@ class TestJitGraphSplit(unittest.TestCase):
     graph_t = graph_class(dev)
     if graph_t is None: return
-    got = f.jit_cache
+    got = f.captured.linear.src
     from tinygrad.runtime.graph.hcq import HCQGraph
+    from tinygrad.engine.jit import MultiGraphRunner
     if graph_t is HCQGraph:
       validate = hcqgraph
     elif issubclass(graph_t, MultiGraphRunner):
@@ -698,16 +699,16 @@ class TestJitGraphSplit(unittest.TestCase):
       validate = graph
     assert len(got) == len(validate), f"Expected {len(validate)} operations, got {len(got)}"
-    for expected, got in zip(validate, got):
+    for expected, si in zip(validate, got):
+      ast = si.src[0]
       if expected["type"] == "graph":
-        assert isinstance(got.prg, GraphRunner), f"Expected GraphRunner, got {type(got.prg)}"
-        assert len(got.prg.jit_cache) == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {len(got.prg.jit_cache)}"
+        assert call_is_graph(si), f"Expected graph, got {ast.op}"
+        inner_cnt = len(ast.src[0].src)
+        assert inner_cnt == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {inner_cnt}"
       elif expected["type"] == "comp":
-        assert isinstance(got.prg, CompiledRunner), f"Expected CompiledRunner, got {type(got.prg)}"
-      elif expected["type"] == "copy":
-        assert isinstance(got.prg, BufferCopy), f"Expected BufferCopy, got {type(got.prg)}"
-      elif expected["type"] == "xfer":
-        assert isinstance(got.prg, BufferXfer), f"Expected BufferXfer, got {type(got.prg)}"
+        assert ast.op in (Ops.SINK, Ops.PROGRAM, Ops.BEAM), f"Expected kernel, got {ast.op}"
+      elif expected["type"] in ("copy", "xfer"):
+        assert ast.op is Ops.COPY, f"Expected COPY, got {ast.op}"

   def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
   def ji_comp(self): return {"type": "comp"}

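The split validation above keys off the shape of the captured schedule: each entry of captured.linear.src is a call whose src[0] is either a kernel AST (Ops.SINK / Ops.PROGRAM / Ops.BEAM), a copy (Ops.COPY), or an Ops.CUSTOM_FUNCTION tagged "graph" whose src[0] is the inner linear UOp holding the graphed calls. A small walker built on that shape; describe_jit is a hypothetical helper inferred from the assertions above, not part of the commit:

from tinygrad.uop.ops import Ops, UOp

def describe_jit(linear: UOp) -> list[str]:
  # label each captured call the way the asserts above classify them
  out = []
  for call in linear.src:
    ast = call.src[0]
    if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
      out.append(f"graph[{len(ast.src[0].src)}]")  # inner linear holds the graphed calls
    elif ast.op is Ops.COPY:
      out.append("copy")
    else:
      out.append(f"kernel:{ast.op}")
  return out
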
test/test_multitensor.py

@@ -1,13 +1,13 @@
-import unittest, functools, random
+import unittest, random
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
 from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.helpers import getenv, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
-from tinygrad.engine.realize import BufferCopy, CompiledRunner, run_schedule
+from tinygrad.engine.realize import CompiledRunner, run_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
-from test.helpers import not_support_multi_device, needs_second_gpu, slow
+from test.helpers import not_support_multi_device, needs_second_gpu, slow, call_is_graph

 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
@@ -544,7 +544,7 @@ class TestMultiTensor(unittest.TestCase):
     b.shard_(devices_2)
     c = jf(a, b)
     np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
-    assert len(jf.jit_cache) > 0
+    assert jf.captured is not None

   def test_multi_tensor_jit_body(self):
     @TinyJit
@@ -558,7 +558,7 @@ class TestMultiTensor(unittest.TestCase):
     for _ in range(5):
       r = jf()
       np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
-    assert len(jf.jit_cache) > 0
+    assert jf.captured is not None

   def test_multitensor_jit_in_list(self):
     # test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
@@ -618,15 +618,12 @@ class TestMultiTensor(unittest.TestCase):
       o = jf(a, b, c, d).numpy()
       np.testing.assert_allclose(ref, o, atol=1e-4, rtol=1e-5)
-    graph_d0 = Device[d0].graph.func if isinstance(Device[d0].graph, functools.partial) else Device[d0].graph
-    graph_d1 = Device[d1].graph.func if isinstance(Device[d1].graph, functools.partial) else Device[d1].graph
     # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
-    assert isinstance(jf.jit_cache[0].prg, graph_d0)
-    assert isinstance(jf.jit_cache[1].prg, graph_d0)
-    assert isinstance(jf.jit_cache[2].prg, graph_d1)
-    assert isinstance(jf.jit_cache[3].prg, graph_d1)
-    assert isinstance(jf.jit_cache[4].prg, BufferCopy)
-    assert isinstance(jf.jit_cache[5].prg, graph_d1)
+    sis = jf.captured.linear.src
+    assert len(sis) == 6
+    for si in (sis[0], sis[1], sis[2], sis[3], sis[5]):
+      assert call_is_graph(si)
+    assert sis[4].src[0].op is Ops.COPY

   def test_bn_ast_on_devices(self):
     t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
test/helpers.py

@@ -5,7 +5,7 @@ import numpy as np
 from tinygrad import Tensor, dtypes, Device
 from tinygrad.uop.ops import UOp, Ops, KernelInfo
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.engine.realize import Runner, get_program
+from tinygrad.engine.realize import get_program
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import T, CI, Target
@@ -31,18 +31,29 @@ def derandomize_model(model):
     p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
     p.realize()

+def call_is_graph(call:UOp) -> bool:
+  ast = call.src[0]
+  return ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph"
+
+def jit_cache_count(linear:UOp) -> int:
+  n = 0
+  for call in linear.src:
+    ast = call.src[0]
+    if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": n += jit_cache_count(ast.src[0])
+    else: n += 1
+  return n
+
 def assert_jit_cache_len(fxn, expected_len):
-  if not fxn.jit_cache:
+  linear = fxn.captured.linear if fxn.captured is not None else None
+  if linear is None or not linear.src:
     assert expected_len == 0, expected_len
     return
-  # until we have a better way of typing the prg in ExecItem
-  if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
-    assert len(fxn.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache)}"
+  if call_is_graph(linear.src[0]):
+    assert len(linear.src) == 1, len(linear.src)
+    inner = linear.src[0].src[0].src[0]  # LINEAR UOp inside CUSTOM_FUNCTION
+    assert len(inner.src) == expected_len, f"expected {expected_len}, got {len(inner.src)}"
   else:
-    assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
-    # until we have a better way of typing the prg in ExecItem
-    assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
-    assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache[0].prg.jit_cache)}"
+    assert len(linear.src) == expected_len, f"expected {expected_len}, got {len(linear.src)}"

 def rand_for_dtype(dt:DType, size:int, allow_subnormal=True):
   if dtypes.is_unsigned(dt):

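The helpers above treat graphs transparently when counting: jit_cache_count recurses into each CUSTOM_FUNCTION "graph" node's inner linear UOp, while assert_jit_cache_len expects the capture to be either a single graph (and compares the inner length) or fully ungraphed (and compares the top-level length). A hedged usage sketch; the jitted step function is hypothetical:

from tinygrad import Tensor, TinyJit
from test.helpers import assert_jit_cache_len, jit_cache_count

@TinyJit
def step(x: Tensor) -> Tensor: return (x.relu().sum() + 1).realize()

for _ in range(3): step(Tensor.rand(32))

# jit_cache_count sees through graph wrapping, so this holds whether or not the
# backend graphed the capture (as long as it is one graph or none at all)
assert_jit_cache_len(step, jit_cache_count(step.captured.linear))
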
test/models/test_real_world.py

@@ -6,7 +6,7 @@ from tinygrad.nn.state import get_parameters
 from tinygrad.engine.jit import TinyJit
 from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
 from tinygrad.helpers import Context
-from test.helpers import slow
+from test.helpers import slow, jit_cache_count
 from extra.lr_scheduler import OneCycleLR
 from test.helpers import derandomize_model
@@ -31,8 +31,7 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
     tms.append(time.perf_counter_ns() - st)

   mem_used = (GlobalCounters.mem_used - global_mem_used) / 1e9
-  # TODO: jit should expose this correctly with graph
-  kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
+  kernels_used = jit_cache_count(model.captured.linear) if getattr(model, "captured", None) is not None else None
   print(f"{nm}: used {mem_used:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
   assert mem_used < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.3f} GB - {mem_used:.3} GB used"
   assert (max_memory_allowed - mem_used) / max_memory_allowed < 0.2, f"{max_memory_allowed:.3f} GB is too far from {mem_used:.3} GB used"