mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
do not use jit_cache in test (#15823)
* do not use jit_cache in test * fix
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest, functools
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import assert_jit_cache_len, not_support_multi_device, needs_second_gpu
|
||||
from test.helpers import assert_jit_cache_len, call_is_graph, not_support_multi_device, needs_second_gpu
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit, JitError, GraphRunner, MultiGraphRunner, graph_class
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
|
||||
from tinygrad.engine.jit import TinyJit, JitError, graph_class
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import Context, JIT, DEV, GlobalCounters
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import Ops
|
||||
from extra.models.unet import ResBlock
|
||||
|
||||
def _simple_test(add, extract=lambda x: x, N=10):
|
||||
@@ -419,10 +419,10 @@ class TestJit(unittest.TestCase):
|
||||
if prev is not None: np.testing.assert_allclose(o, prev, atol=1e-4, rtol=1e-5)
|
||||
prev = o
|
||||
|
||||
graph_t = Device[Device.DEFAULT].graph.func if isinstance(Device[Device.DEFAULT].graph, functools.partial) else Device[Device.DEFAULT].graph
|
||||
# Checking that 2 graphs are inited.
|
||||
assert isinstance(jf.jit_cache[0].prg, graph_t)
|
||||
assert isinstance(jf.jit_cache[1].prg, graph_t)
|
||||
assert len(jf.captured.linear.src) == 2
|
||||
for si in jf.captured.linear.src:
|
||||
assert call_is_graph(si)
|
||||
|
||||
def test_jitted_clone(self):
|
||||
def f(a): return a.clone().realize()
|
||||
@@ -583,7 +583,7 @@ class TestJitPrune(unittest.TestCase):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert len(w2_prune.captured.jit_cache) == 1
|
||||
assert_jit_cache_len(w2_prune, 1)
|
||||
|
||||
def test_prune_w_copy_correct(self):
|
||||
weights = Tensor.rand(16).realize()
|
||||
@@ -617,7 +617,7 @@ class TestJitPrune(unittest.TestCase):
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
|
||||
assert_jit_cache_len(w2_prune, 1)
|
||||
|
||||
class TestJitFree(unittest.TestCase):
|
||||
def test_free_intermediates(self):
|
||||
@@ -688,8 +688,9 @@ class TestJitGraphSplit(unittest.TestCase):
|
||||
graph_t = graph_class(dev)
|
||||
if graph_t is None: return
|
||||
|
||||
got = f.jit_cache
|
||||
got = f.captured.linear.src
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
if graph_t is HCQGraph:
|
||||
validate = hcqgraph
|
||||
elif issubclass(graph_t, MultiGraphRunner):
|
||||
@@ -698,16 +699,16 @@ class TestJitGraphSplit(unittest.TestCase):
|
||||
validate = graph
|
||||
|
||||
assert len(got) == len(validate), f"Expected {len(validate)} operations, got {len(got)}"
|
||||
for expected, got in zip(validate, got):
|
||||
for expected, si in zip(validate, got):
|
||||
ast = si.src[0]
|
||||
if expected["type"] == "graph":
|
||||
assert isinstance(got.prg, GraphRunner), f"Expected GraphRunner, got {type(got.prg)}"
|
||||
assert len(got.prg.jit_cache) == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {len(got.prg.jit_cache)}"
|
||||
assert call_is_graph(si), f"Expected graph, got {ast.op}"
|
||||
inner_cnt = len(ast.src[0].src)
|
||||
assert inner_cnt == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {inner_cnt}"
|
||||
elif expected["type"] == "comp":
|
||||
assert isinstance(got.prg, CompiledRunner), f"Expected CompiledRunner, got {type(got.prg)}"
|
||||
elif expected["type"] == "copy":
|
||||
assert isinstance(got.prg, BufferCopy), f"Expected BufferCopy, got {type(got.prg)}"
|
||||
elif expected["type"] == "xfer":
|
||||
assert isinstance(got.prg, BufferXfer), f"Expected BufferXfer, got {type(got.prg)}"
|
||||
assert ast.op in (Ops.SINK, Ops.PROGRAM, Ops.BEAM), f"Expected kernel, got {ast.op}"
|
||||
elif expected["type"] in ("copy", "xfer"):
|
||||
assert ast.op is Ops.COPY, f"Expected COPY, got {ast.op}"
|
||||
|
||||
def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
|
||||
def ji_comp(self): return {"type": "comp"}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import unittest, functools, random
|
||||
import unittest, random
|
||||
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.helpers import getenv, prod, Context
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.realize import BufferCopy, CompiledRunner, run_schedule
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
import numpy as np
|
||||
from hypothesis import given, strategies as strat, settings
|
||||
from test.helpers import not_support_multi_device, needs_second_gpu, slow
|
||||
from test.helpers import not_support_multi_device, needs_second_gpu, slow, call_is_graph
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
@@ -544,7 +544,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
b.shard_(devices_2)
|
||||
c = jf(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert len(jf.jit_cache) > 0
|
||||
assert jf.captured is not None
|
||||
|
||||
def test_multi_tensor_jit_body(self):
|
||||
@TinyJit
|
||||
@@ -558,7 +558,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
for _ in range(5):
|
||||
r = jf()
|
||||
np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
|
||||
assert len(jf.jit_cache) > 0
|
||||
assert jf.captured is not None
|
||||
|
||||
def test_multitensor_jit_in_list(self):
|
||||
# test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
|
||||
@@ -618,15 +618,12 @@ class TestMultiTensor(unittest.TestCase):
|
||||
o = jf(a, b, c, d).numpy()
|
||||
np.testing.assert_allclose(ref, o, atol=1e-4, rtol=1e-5)
|
||||
|
||||
graph_d0 = Device[d0].graph.func if isinstance(Device[d0].graph, functools.partial) else Device[d0].graph
|
||||
graph_d1 = Device[d1].graph.func if isinstance(Device[d1].graph, functools.partial) else Device[d1].graph
|
||||
# Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
|
||||
assert isinstance(jf.jit_cache[0].prg, graph_d0)
|
||||
assert isinstance(jf.jit_cache[1].prg, graph_d0)
|
||||
assert isinstance(jf.jit_cache[2].prg, graph_d1)
|
||||
assert isinstance(jf.jit_cache[3].prg, graph_d1)
|
||||
assert isinstance(jf.jit_cache[4].prg, BufferCopy)
|
||||
assert isinstance(jf.jit_cache[5].prg, graph_d1)
|
||||
sis = jf.captured.linear.src
|
||||
assert len(sis) == 6
|
||||
for si in (sis[0], sis[1], sis[2], sis[3], sis[5]):
|
||||
assert call_is_graph(si)
|
||||
assert sis[4].src[0].op is Ops.COPY
|
||||
|
||||
def test_bn_ast_on_devices(self):
|
||||
t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
|
||||
|
||||
@@ -5,7 +5,7 @@ import numpy as np
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.engine.realize import Runner, get_program
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.helpers import T, CI, Target
|
||||
@@ -31,18 +31,29 @@ def derandomize_model(model):
|
||||
p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
|
||||
p.realize()
|
||||
|
||||
def call_is_graph(call:UOp) -> bool:
|
||||
ast = call.src[0]
|
||||
return ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph"
|
||||
|
||||
def jit_cache_count(linear:UOp) -> int:
|
||||
n = 0
|
||||
for call in linear.src:
|
||||
ast = call.src[0]
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": n += jit_cache_count(ast.src[0])
|
||||
else: n += 1
|
||||
return n
|
||||
|
||||
def assert_jit_cache_len(fxn, expected_len):
|
||||
if not fxn.jit_cache:
|
||||
linear = fxn.captured.linear if fxn.captured is not None else None
|
||||
if linear is None or not linear.src:
|
||||
assert expected_len == 0, expected_len
|
||||
return
|
||||
# until we have a better way of typing the prg in ExecItem
|
||||
if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
|
||||
assert len(fxn.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache)}"
|
||||
if call_is_graph(linear.src[0]):
|
||||
assert len(linear.src) == 1, len(linear.src)
|
||||
inner = linear.src[0].src[0].src[0] # LINEAR UOp inside CUSTOM_FUNCTION
|
||||
assert len(inner.src) == expected_len, f"expected {expected_len}, got {len(inner.src)}"
|
||||
else:
|
||||
assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
|
||||
# until we have a better way of typing the prg in ExecItem
|
||||
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache[0].prg.jit_cache)}"
|
||||
assert len(linear.src) == expected_len, f"expected {expected_len}, got {len(linear.src)}"
|
||||
|
||||
def rand_for_dtype(dt:DType, size:int, allow_subnormal=True):
|
||||
if dtypes.is_unsigned(dt):
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
|
||||
from tinygrad.helpers import Context
|
||||
from test.helpers import slow
|
||||
from test.helpers import slow, jit_cache_count
|
||||
from extra.lr_scheduler import OneCycleLR
|
||||
from test.helpers import derandomize_model
|
||||
|
||||
@@ -31,8 +31,7 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
|
||||
tms.append(time.perf_counter_ns() - st)
|
||||
mem_used = (GlobalCounters.mem_used - global_mem_used) / 1e9
|
||||
|
||||
# TODO: jit should expose this correctly with graph
|
||||
kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
|
||||
kernels_used = jit_cache_count(model.captured.linear) if getattr(model, "captured", None) is not None else None
|
||||
print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
|
||||
assert mem_used < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.3f} GB - {mem_used:.3} GB used"
|
||||
assert (max_memory_allowed - mem_used) / max_memory_allowed < 0.2, f"{max_memory_allowed:.3f} GB is too far from {mem_used:.3} GB used"
|
||||
|
||||
Reference in New Issue
Block a user