jit: rename optimize_weights -> replan_buffers_memory_layout (#9751)

This commit is contained in:
nimlgen
2025-04-05 20:35:15 +03:00
committed by GitHub
parent 493fb315b1
commit c2573b247c
2 changed files with 6 additions and 6 deletions

View File

@@ -617,8 +617,8 @@ class TestJitFree(unittest.TestCase):
fxn(Tensor([2]))
self.assertEqual(x.item(), 8)
def test_optimize_weights(self):
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("optimize_weights useless")
def test_replan_buffers_memory_layout(self):
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")
ext_tensor = Tensor([1,24,23,45,1])
ext_tensor_2 = Tensor([2,2,2,2,2])
@@ -630,7 +630,7 @@ class TestJitFree(unittest.TestCase):
out = fxn(Tensor([i,1,2,3,4]))
self.assertEqual(out.item(), 11400+200*i)
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
fxn.captured.optimize_weights()
fxn.captured.replan_buffers_memory_layout()
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2
out = fxn(Tensor([11,1,2,3,4]))

View File

@@ -149,7 +149,7 @@ class CapturedJit(Generic[ReturnType]):
expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
def __reduce__(self):
# TODO: free_intermediates here? optimize_weights here?
# TODO: free_intermediates here? replan_buffers_memory_layout here?
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
self.expected_names, self.expected_st_vars_dtype_device)
@@ -171,7 +171,7 @@ class CapturedJit(Generic[ReturnType]):
if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
self.__post_init__() # reset the graph state
def optimize_weights(self):
def replan_buffers_memory_layout(self):
blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
@@ -314,7 +314,7 @@ class TinyJit(Generic[ReturnType]):
# set this for next run
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
if self.optimize: self.captured.optimize_weights()
if self.optimize: self.captured.replan_buffers_memory_layout()
elif self.cnt >= 2:
# jit exec
assert self.captured is not None