Fix incorrect jit current batch devs reset (#11505)

`current_batch_devs = []` (in `flush_batch()`) runs between
`new_batched_devs = ...` and `current_batch_devs = new_batched_devs`, so the
reset is immediately overwritten with the stale pre-flush value and nothing is
actually reset, which leaves later items unable to jit properly.

This 2xs remote BERT step time (and should have similar effects on any
non-hcq backend).
Author: uuuvn
Date: 2025-08-05 05:16:16 +00:00, committed by GitHub
Parent: f02720ca2d
Commit: 011ef8fa9d
2 changed files with 31 additions and 4 deletions
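To make the failure mode concrete before reading the diffs, here is a minimal, self-contained sketch of the bug pattern. The names (`batch_devs_buggy`, the toy one-device-per-batch rule, the re-implemented `dedup`) are hypothetical illustrations, not tinygrad's actual code: a value computed *before* `flush_batch()` is assigned back *after* it, so the reset inside `flush_batch()` never takes effect.

```python
def dedup(xs): return list(dict.fromkeys(xs))

def batch_devs_buggy(devs: list[str]) -> list[list[str]]:
  # toy batcher: group consecutive items that share one device, flush otherwise
  batches, current_batch, current_batch_devs = [], [], []
  def flush_batch():
    nonlocal current_batch, current_batch_devs
    batches.append(current_batch)
    current_batch, current_batch_devs = [], []  # this reset is about to be clobbered
  for dev in devs:
    new_batched_devs = dedup(current_batch_devs + [dev])  # computed BEFORE the flush
    if len(current_batch) > 0 and len(new_batched_devs) > 1: flush_batch()
    current_batch.append(dev)
    current_batch_devs = new_batched_devs  # BUG: stale pre-flush value overwrites the reset
  if len(current_batch) > 0: flush_batch()
  return batches

# After the first flush, current_batch_devs still carries the flushed batch's
# devices, so every later item looks incompatible and lands in its own batch:
print(batch_devs_buggy(["GPU", "GPU", "CPU", "CPU"]))  # [['GPU', 'GPU'], ['CPU'], ['CPU']]
# a reset that actually sticks would give [['GPU', 'GPU'], ['CPU', 'CPU']]
```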

test/test_jit.py

@@ -5,9 +5,10 @@ import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit
+from tinygrad.engine.jit import TinyJit, GraphRunner
 from tinygrad.device import Device
 from tinygrad.helpers import Context, JIT, GlobalCounters
+from tinygrad.runtime.support.hcq import HCQCompiled
 from tinygrad.dtype import dtypes
 from extra.models.unet import ResBlock
@@ -471,6 +472,32 @@ class TestJit(unittest.TestCase):
     np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
 
+  @unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
+  def test_jit_several_incompatible_devs(self):
+    assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
+    assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None
+    d0, d1 = Device.DEFAULT, "CPU"
+    @TinyJit
+    def f(a0, b0):
+      a1 = (a0 + 2.0).contiguous().realize()
+      a2 = (a1 * 2.0).contiguous().realize()
+      b1 = (b0 + 2.0).contiguous().realize()
+      b2 = (b1 * 2.0).contiguous().realize()
+      return a2, b2
+    for _ in range(5):
+      a = Tensor.randn(10, 10, device=d0).realize()
+      b = Tensor.randn(10, 10, device=d1).realize()
+      a1, b1 = f(a, b)
+      np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
+      np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)
+    assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
+
   @unittest.skipIf(not_support_multi_device(), "no multi")
   def test_jitted_view(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"

tinygrad/engine/jit.py

@@ -52,14 +52,14 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
     can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
 
     # Check if the current batch can be extended with this item.
-    new_batched_devs = dedup(current_batch_devs + [ji_graph_dev])
-    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and graph_class(current_batch_devs[0]).supports_exec_item(new_batched_devs, ji)
+    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
+      graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
     can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
 
     # Flush the current batch if any, since it can't be extended or is full.
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
     (current_batch if can_be_graphed else graphed_jit_cache).append(ji)
-    current_batch_devs = new_batched_devs if can_be_graphed else []
+    current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
 
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache