Mirror of https://github.com/tinygrad/tinygrad.git
Fix incorrect jit current batch devs reset (#11505)
In `apply_graph_to_jit`, the `current_batch_devs = []` reset inside `flush_batch()` ran between `new_batched_devs = ...` and the later `current_batch_devs = new_batched_devs`, so the assignment immediately overwrote the reset with the stale pre-flush device list. The batch device state therefore never actually reset, kernels stopped being batched into JIT graphs properly, and remote BERT step time roughly doubled; any non-HCQ backend should see a similar effect.
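To make the ordering problem concrete, below is a minimal sketch of the pattern being fixed. It is not tinygrad's actual code: `add_item`, `can_extend_batch`, and the module-level globals are simplifications invented for this example; only the ordering of "compute the new device list, maybe flush, then assign" mirrors the bug described above.

# Minimal sketch of the ordering bug (simplified/hypothetical, not the real apply_graph_to_jit).
current_batch, current_batch_devs = [], []

def flush_batch():
  global current_batch, current_batch_devs
  print("flushing batch", current_batch, "on", current_batch_devs)
  current_batch, current_batch_devs = [], []          # the reset that gets clobbered

def add_item(item, dev, can_extend_batch):
  global current_batch, current_batch_devs
  new_batched_devs = sorted(set(current_batch_devs + [dev]))  # computed BEFORE the flush
  if not can_extend_batch and current_batch: flush_batch()    # resets current_batch_devs to []
  current_batch.append(item)
  current_batch_devs = new_batched_devs                       # ...and immediately overwrites the reset

add_item("kernel0", "CPU", can_extend_batch=True)
add_item("kernel1", "REMOTE", can_extend_batch=False)  # flushes, yet the stale device list survives:
print(current_batch_devs)                              # ['CPU', 'REMOTE'] instead of just ['REMOTE']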
@@ -5,9 +5,10 @@ import numpy as np
 from hypothesis import given, settings, strategies as strat
 from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import TinyJit
+from tinygrad.engine.jit import TinyJit, GraphRunner
 from tinygrad.device import Device
 from tinygrad.helpers import Context, JIT, GlobalCounters
+from tinygrad.runtime.support.hcq import HCQCompiled
 from tinygrad.dtype import dtypes
 from extra.models.unet import ResBlock
 
@@ -471,6 +472,32 @@ class TestJit(unittest.TestCase):
     np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
 
+  @unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
+  def test_jit_several_incompatible_devs(self):
+    assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
+    assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None
+
+    d0, d1 = Device.DEFAULT, "CPU"
+
+    @TinyJit
+    def f(a0, b0):
+      a1 = (a + 2.0).contiguous().realize()
+      a2 = (a1 * 2.0).contiguous().realize()
+
+      b1 = (b0 + 2.0).contiguous().realize()
+      b2 = (b1 * 2.0).contiguous().realize()
+
+      return a2, b2
+
+    for _ in range(5):
+      a = Tensor.randn(10, 10, device=d0).realize()
+      b = Tensor.randn(10, 10, device=d1).realize()
+      a1, b1 = f(a, b)
+      np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
+      np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)
+
+    assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
+
   @unittest.skipIf(not_support_multi_device(), "no multi")
   def test_jitted_view(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
@@ -52,14 +52,14 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
     can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
 
     # Check if the current batch can be extended with this item.
-    new_batched_devs = dedup(current_batch_devs + [ji_graph_dev])
-    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and graph_class(current_batch_devs[0]).supports_exec_item(new_batched_devs, ji)
+    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
+      graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
     can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
 
     # Flush the current batch if any, since it can't be extended or is full.
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
     (current_batch if can_be_graphed else graphed_jit_cache).append(ji)
-    current_batch_devs = new_batched_devs if can_be_graphed else []
+    current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
 
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache
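For context, here is a condensed paraphrase of the batching flow after this change. `batch_items`, `compatible`, and the local `dedup` are stand-ins invented for this sketch (the real logic lives in `apply_graph_to_jit` and consults `graph_class(...).supports_exec_item`); the point it illustrates is that the device list is now rebuilt from `current_batch_devs` after any `flush_batch()`, so a freshly started batch is tracked only by the new item's device.

# Condensed paraphrase of the fixed control flow (assumed/simplified, not the full apply_graph_to_jit).
def dedup(xs): return list(dict.fromkeys(xs))

def batch_items(items, max_batch_size=0):
  batches, current_batch, current_batch_devs = [], [], []

  def flush_batch():
    nonlocal current_batch, current_batch_devs
    batches.append((tuple(current_batch_devs), list(current_batch)))
    current_batch, current_batch_devs = [], []                # this reset now survives...

  for item, dev, compatible in items:                         # `compatible`: can this item share a graph with the batch?
    can_share = compatible and len(current_batch_devs) > 0
    can_extend = can_share and (max_batch_size == 0 or len(current_batch) < max_batch_size)
    if not can_extend and current_batch: flush_batch()
    current_batch.append(item)
    current_batch_devs = dedup(current_batch_devs + [dev])    # ...because the devs are recomputed AFTER the flush

  if current_batch: flush_batch()
  return batches

# A REMOTE item that can't share a graph with the CPU batch now starts a batch tracked as ("REMOTE",) only.
print(batch_items([("k0", "CPU", True), ("k1", "CPU", True), ("k2", "REMOTE", False)]))

With the stale `new_batched_devs` gone, the new `test_jit_several_incompatible_devs` above can assert that every cached ExecItem ends up wrapped in a `GraphRunner`.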