jit: optimize before pickle (#9611)

* jit: optimize before pickle

* optimize weights

* fix

* mypy

* mypy2
Author: nimlgen
Date: 2025-03-28 19:06:09 +07:00
Committed by: GitHub
parent 392a311312
commit fa0ebbd237
3 changed files with 32 additions and 3 deletions
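
In short: CapturedJit gains an optimize_weights() pass and __reduce__ now runs it before building the pickle state, so the buffers captured in jit_cache are consolidated by the memory planner before a JIT is serialized. A minimal usage sketch, assuming the usual tinygrad flow of pickling a @TinyJit function after it has been captured (the model and values here are illustrative, not taken from this change):

import pickle
from tinygrad import Tensor, TinyJit

weight = Tensor([1.0, 2.0, 3.0, 4.0])

@TinyJit
def fxn(x: Tensor) -> Tensor:
  return (x * weight).sum()

# run a few times so the JIT captures its kernels
for i in range(3): fxn(Tensor([float(i), 1.0, 2.0, 3.0])).item()

# pickling goes through CapturedJit.__reduce__, which now calls optimize_weights()
# first, so the serialized jit_cache references the planner-packed buffers
blob = pickle.dumps(fxn)
restored = pickle.loads(blob)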

test/test_jit.py

@@ -605,5 +605,24 @@ class TestJitFree(unittest.TestCase):
     fxn(Tensor([2]))
     self.assertEqual(x.item(), 8)
+  def test_optimize_weights(self):
+    if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("optimize_weights useless")
+    ext_tensor = Tensor([1,24,23,45,1])
+    ext_tensor_2 = Tensor([2,2,2,2,2])
+    @TinyJit
+    def fxn(x:Tensor):
+      out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
+      return out.sum()
+
+    for i in range(5):
+      out = fxn(Tensor([i,1,2,3,4]))
+      self.assertEqual(out.item(), 11400+200*i)
+
+    assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
+    fxn.captured.optimize_weights()
+    assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2
+
+    out = fxn(Tensor([11,1,2,3,4]))
+    self.assertEqual(out.item(), 13600)

 if __name__ == '__main__':
   unittest.main()
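
For reference, the expected values in the new test follow directly from the expression: with x = [i,1,2,3,4], x*ext_tensor_2 + ext_tensor = [2i+1, 26, 27, 51, 9], which sums to 2i+114; expanding to (5, 100) before the final sum multiplies that by 100, giving 11400 + 200*i, and 13600 for the follow-up call with i=11. The distinct-base count dropping from 4 to 2 is the observable effect of optimize_weights(): the planner aliases the captured non-input buffers into a shared base allocation, which is also why the test is skipped on allocators without _offset support.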

tinygrad/engine/jit.py

@@ -150,6 +150,7 @@ class CapturedJit(Generic[ReturnType]):
   def __reduce__(self):
     # TODO: free_intermediates here?
+    self.optimize_weights()
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                             self.expected_names, self.expected_st_vars_dtype_device)
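
The hook point is Python's pickle protocol: __reduce__ returns the class plus the constructor arguments to store, so any in-place rewrite done before that tuple is built is what ends up serialized. A generic sketch of the mechanism (plain Python, not the tinygrad classes):

import pickle

class Captured:
  def __init__(self, cache): self.cache = cache
  def optimize(self):
    # consolidate state in place before it gets snapshotted
    self.cache = sorted(set(self.cache))
  def __reduce__(self):
    self.optimize()                       # runs when pickle.dumps() is called
    return self.__class__, (self.cache,)  # unpickling calls Captured(cache)

c = Captured([3, 1, 3, 2])
print(pickle.loads(pickle.dumps(c)).cache)  # [1, 2, 3]: the optimized state was stored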
@@ -171,6 +172,14 @@ class CapturedJit(Generic[ReturnType]):
       if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
     self.__post_init__() # reset the graph state

+  def optimize_weights(self):
+    blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
+    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
+    self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
+    for old, new in asgn.items():
+      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
+    self.__post_init__()
+
   # jit exec
   def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
     # assign inputs
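
optimize_weights() is a thin wrapper over the planner: buffers belonging to the returned parameters are blacklisted (presumably because self.ret is handed back to the caller, so those buffers must keep their identity), everything else referenced by jit_cache goes through _internal_memory_planner with ignore_checks=True, each ExecItem is rebuilt against the returned old-to-new mapping, and data already resident in the old buffers (loaded weights) is copied into the new ones. The rebuild step is just the asgn.get(b, b) idiom; a trivial standalone illustration, with strings standing in for Buffer objects:

asgn = {"w_old": "w_packed"}                   # planner output: old -> new
cache = [("kernel0", ["x", "w_old", None, "out"])]
cache = [(prg, [asgn.get(b, b) if b is not None else None for b in bufs]) for prg, bufs in cache]
assert cache == [("kernel0", ["x", "w_packed", None, "out"])]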

tinygrad/engine/memory.py

@@ -9,12 +9,13 @@ from tinygrad.runtime.support.allocator import TLSFAllocator

 # **************** memory planning ****************

-def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers=None, debug_prefix="") -> dict[Buffer, Buffer]:
+def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
   if NO_MEMORY_PLANNER: return {}
   first_appearance, last_appearance, buf_to_opt = {}, {}, set()
   for i,u in enumerate(buffers):
     for buf in u:
-      if buf.is_allocated() or buf.base.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
+      if not ignore_checks and should_skip: continue
       if buf.base not in first_appearance: first_appearance[buf.base] = i
       last_appearance[buf.base] = i
       buf_to_opt.add(buf)
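
ignore_checks=True disables the planner's usual exclusions: by default it refuses to touch buffers that are already allocated (or whose base is) or that still carry lazybuffer references, which is exactly the state captured weights are in. For the optimize_weights() path those buffers are fair game, the explicit blacklist being the only opt-out, since the caller copies the live data across afterwards. The planning itself is liveness-based: first_appearance/last_appearance give each buffer a live interval, and buffers whose intervals do not overlap can share storage. A much-simplified sketch of that idea, reusing whole slots rather than packing byte offsets the way the real planner does with TLSFAllocator:

def plan(intervals: dict[str, tuple[int, int]]) -> dict[str, int]:
  slots: list[int] = []             # step at which each slot's current occupant dies
  assignment: dict[str, int] = {}
  for name, (first, last) in sorted(intervals.items(), key=lambda kv: kv[1][0]):
    for i, free_at in enumerate(slots):
      if free_at < first:           # previous occupant is dead before we start: reuse
        slots[i] = last
        assignment[name] = i
        break
    else:
      slots.append(last)            # no reusable slot: open a new one
      assignment[name] = len(slots) - 1
  return assignment

print(plan({"a": (0, 1), "c": (1, 2), "b": (2, 3)}))  # a and b share slot 0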
@@ -63,6 +64,6 @@ def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noop

 def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
   # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
-  assigned = _internal_memory_planner([si.bufs for si in schedule],
+  assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
                                       noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
   return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
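
The memory_planner change is fallout from the tightened annotation: the buffers parameter no longer admits tuples, and ScheduleItem.bufs is a tuple (the return line rebuilds it with tuple(...)), so the call site converts it with list(...). This is likely what the "mypy"/"mypy2" fixups in the squashed history refer to; the mismatch in miniature (generic typing example, unrelated to the tinygrad classes):

def f(xs: list[list[int]]) -> None: pass

t: tuple[int, ...] = (1, 2, 3)
f([t])        # flagged by mypy: a tuple is not a list[int]
f([list(t)])  # fine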