jit: rm replan (#15433)

Authored by nimlgen on 2026-03-23 19:31:51 +08:00, committed by GitHub
parent c4c53418f8
commit 2da008ae3b
3 changed files with 3 additions and 32 deletions
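
In API terms, this commit removes CapturedJit.replan_buffers_memory_layout() and the optimize= keyword of TinyJit; prune= stays. A minimal sketch of a call site after this change (the lambda and tensors below are illustrative, not taken from this commit):

from tinygrad import Tensor, TinyJit

jitted = TinyJit(lambda x: (x*2).sum(), prune=True)  # passing optimize=True would now raise a TypeError
for i in range(3):
  out = jitted(Tensor([float(i), 1.0, 2.0]))  # early calls capture the kernels, later calls replay the JIT cache
  print(out.item())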


@@ -5,7 +5,7 @@ from extra.onnx_helpers import get_example_inputs, validate
 def load_onnx_model(onnx_file):
   run_onnx = OnnxRunner(onnx_file)
-  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
+  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True)
   return run_onnx_jit, run_onnx.graph_inputs
 
 if __name__ == "__main__":


@@ -660,25 +660,6 @@ class TestJitFree(unittest.TestCase):
     fxn(Tensor([2]))
     self.assertEqual(x.item(), 8)
-
-  def test_replan_buffers_memory_layout(self):
-    if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")
-    ext_tensor = Tensor([1,24,23,45,1]).contiguous()
-    ext_tensor_2 = Tensor([2,2,2,2,2]).contiguous()
-    @TinyJit
-    def fxn(x:Tensor):
-      out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
-      return out.sum()
-
-    for i in range(5):
-      out = fxn(Tensor([i,1,2,3,4]))
-      self.assertEqual(out.item(), 11400+200*i)
-
-    self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 4)
-    fxn.captured.replan_buffers_memory_layout()
-    self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 2)
-    out = fxn(Tensor([11,1,2,3,4]))
-    self.assertEqual(out.item(), 13600)
 
 class TestJitGraphSplit(unittest.TestCase):
   def compute(self, device, inp):
     assert inp.device == device, f"Input device {inp.device} does not match expected {device}"


@@ -190,7 +190,7 @@ class CapturedJit(Generic[ReturnType]):
   expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]]  # (view, variables, dtype, device) per input
 
   def __reduce__(self):
-    # TODO: free_intermediates here? replan_buffers_memory_layout here?
+    # TODO: free_intermediates here?
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
 
   def __post_init__(self):
@@ -218,14 +218,6 @@ class CapturedJit(Generic[ReturnType]):
       if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
     self.__post_init__()  # reset the graph state
 
-  def replan_buffers_memory_layout(self):
-    blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
-    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
-    self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
-    for old, new in asgn.items():
-      if old.is_allocated(): new.ensure_allocated().copyin(old.as_memoryview())
-    self.__post_init__()
-
   # jit exec
   def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
     # assign inputs
@@ -282,13 +274,12 @@ def _prepare_jit_inputs(args, kwargs):
   return input_buffers, var_vals, names, expected_input_info
 
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
+  def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False):
     assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
     self.captured: CapturedJit|None = captured
     self.cnt: int = 2 if self.fxn is None else 0
     self.prune = prune
-    self.optimize = optimize
 
   def add_linear(self, linear:UOp, var_vals:dict[str, int]): self._linears.append(linear)
@@ -367,7 +358,6 @@ class TinyJit(Generic[ReturnType]):
         for ei in jit_cache: ei.run(var_vals)
       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
-      if self.optimize: self.captured.replan_buffers_memory_layout()
     elif self.cnt >= 2:
       # jit exec
       assert self.captured is not None