diff --git a/test/test_jit.py b/test/test_jit.py index 30842ca3b2..7c50b93adb 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -617,8 +617,8 @@ class TestJitFree(unittest.TestCase): fxn(Tensor([2])) self.assertEqual(x.item(), 8) - def test_optimize_weights(self): - if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("optimize_weights useless") + def test_replan_buffers_memory_layout(self): + if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless") ext_tensor = Tensor([1,24,23,45,1]) ext_tensor_2 = Tensor([2,2,2,2,2]) @@ -630,7 +630,7 @@ class TestJitFree(unittest.TestCase): out = fxn(Tensor([i,1,2,3,4])) self.assertEqual(out.item(), 11400+200*i) assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4 - fxn.captured.optimize_weights() + fxn.captured.replan_buffers_memory_layout() assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2 out = fxn(Tensor([11,1,2,3,4])) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index d8b7b60a31..0e36dac219 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -149,7 +149,7 @@ class CapturedJit(Generic[ReturnType]): expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]] def __reduce__(self): - # TODO: free_intermediates here? optimize_weights here? + # TODO: free_intermediates here? replan_buffers_memory_layout here? return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_st_vars_dtype_device) @@ -171,7 +171,7 @@ class CapturedJit(Generic[ReturnType]): if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate() self.__post_init__() # reset the graph state - def optimize_weights(self): + def replan_buffers_memory_layout(self): blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)] asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True) self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache] @@ -314,7 +314,7 @@ class TinyJit(Generic[ReturnType]): # set this for next run self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device) - if self.optimize: self.captured.optimize_weights() + if self.optimize: self.captured.replan_buffers_memory_layout() elif self.cnt >= 2: # jit exec assert self.captured is not None