From 2da008ae3b06f48485fcae50c4af76290e670958 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 23 Mar 2026 19:31:51 +0800 Subject: [PATCH] jit: rm replan (#15433) --- examples/benchmark_onnx.py | 2 +- test/backend/test_jit.py | 19 ------------------- tinygrad/engine/jit.py | 14 ++------------ 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py index 27568117f3..71bf7b8ed5 100644 --- a/examples/benchmark_onnx.py +++ b/examples/benchmark_onnx.py @@ -5,7 +5,7 @@ from extra.onnx_helpers import get_example_inputs, validate def load_onnx_model(onnx_file): run_onnx = OnnxRunner(onnx_file) - run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True) + run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True) return run_onnx_jit, run_onnx.graph_inputs if __name__ == "__main__": diff --git a/test/backend/test_jit.py b/test/backend/test_jit.py index d2fc3d16bb..765600d370 100644 --- a/test/backend/test_jit.py +++ b/test/backend/test_jit.py @@ -660,25 +660,6 @@ class TestJitFree(unittest.TestCase): fxn(Tensor([2])) self.assertEqual(x.item(), 8) - def test_replan_buffers_memory_layout(self): - if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless") - - ext_tensor = Tensor([1,24,23,45,1]).contiguous() - ext_tensor_2 = Tensor([2,2,2,2,2]).contiguous() - @TinyJit - def fxn(x:Tensor): - out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous() - return out.sum() - for i in range(5): - out = fxn(Tensor([i,1,2,3,4])) - self.assertEqual(out.item(), 11400+200*i) - self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 4) - fxn.captured.replan_buffers_memory_layout() - self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 2) - - out = fxn(Tensor([11,1,2,3,4])) - self.assertEqual(out.item(), 13600) - class TestJitGraphSplit(unittest.TestCase): def compute(self, device, inp): assert inp.device == device, f"Input device {inp.device} does not match expected {device}" diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 6f4d0cdfac..ad1ece845b 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -190,7 +190,7 @@ class CapturedJit(Generic[ReturnType]): expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input def __reduce__(self): - # TODO: free_intermediates here? replan_buffers_memory_layout here? + # TODO: free_intermediates here? return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info) def __post_init__(self): @@ -218,14 +218,6 @@ class CapturedJit(Generic[ReturnType]): if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate() self.__post_init__() # reset the graph state - def replan_buffers_memory_layout(self): - blacklist = [t.uop.buffer for t in get_parameters(self.ret)] - asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True) - self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache] - for old, new in asgn.items(): - if old.is_allocated(): new.ensure_allocated().copyin(old.as_memoryview()) - self.__post_init__() - # jit exec def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType: # assign inputs @@ -282,13 +274,12 @@ def _prepare_jit_inputs(args, kwargs): return input_buffers, var_vals, names, expected_input_info class TinyJit(Generic[ReturnType]): - def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False): + def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False): assert fxn or captured, "need either a function or a CapturedJit" self.fxn = fxn self.captured: CapturedJit|None = captured self.cnt: int = 2 if self.fxn is None else 0 self.prune = prune - self.optimize = optimize def add_linear(self, linear:UOp, var_vals:dict[str, int]): self._linears.append(linear) @@ -367,7 +358,6 @@ class TinyJit(Generic[ReturnType]): for ei in jit_cache: ei.run(var_vals) self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info) - if self.optimize: self.captured.replan_buffers_memory_layout() elif self.cnt >= 2: # jit exec assert self.captured is not None