mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
jit: rename optimize_weights -> replan_buffers_memory_layout (#9751)
This commit is contained in:
@@ -617,8 +617,8 @@ class TestJitFree(unittest.TestCase):
     fxn(Tensor([2]))
     self.assertEqual(x.item(), 8)

-  def test_optimize_weights(self):
-    if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("optimize_weights useless")
+  def test_replan_buffers_memory_layout(self):
+    if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")

     ext_tensor = Tensor([1,24,23,45,1])
     ext_tensor_2 = Tensor([2,2,2,2,2])
@@ -630,7 +630,7 @@ class TestJitFree(unittest.TestCase):
       out = fxn(Tensor([i,1,2,3,4]))
       self.assertEqual(out.item(), 11400+200*i)
     assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
-    fxn.captured.optimize_weights()
+    fxn.captured.replan_buffers_memory_layout()
    assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2

     out = fxn(Tensor([11,1,2,3,4]))
@@ -149,7 +149,7 @@ class CapturedJit(Generic[ReturnType]):
   expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

   def __reduce__(self):
-    # TODO: free_intermediates here? optimize_weights here?
+    # TODO: free_intermediates here? replan_buffers_memory_layout here?
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                             self.expected_names, self.expected_st_vars_dtype_device)
@@ -171,7 +171,7 @@ class CapturedJit(Generic[ReturnType]):
         if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
     self.__post_init__()   # reset the graph state

-  def optimize_weights(self):
+  def replan_buffers_memory_layout(self):
     blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
     asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
     self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
@@ -314,7 +314,7 @@ class TinyJit(Generic[ReturnType]):

       # set this for next run
       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
-      if self.optimize: self.captured.optimize_weights()
+      if self.optimize: self.captured.replan_buffers_memory_layout()
     elif self.cnt >= 2:
       # jit exec
       assert self.captured is not None
Reference in New Issue
Block a user