From dc9da1d917c7bb8b4d2b7028dd31b29c359690d9 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 27 Mar 2025 01:46:50 +0700 Subject: [PATCH] memplan into one buffer (#9526) * new memplanner * new should works * fix * VALIDATE_MEMORY_PLANNER * hm? * ugh * fix alignment * fix2 * rm * tiny fixes * test * comments and fixes * fix2 * liiiinetr * t * fix --- test/test_jit.py | 4 ++- tinygrad/device.py | 4 ++- tinygrad/engine/jit.py | 4 ++- tinygrad/engine/memory.py | 64 +++++++++++++++++++++++++-------------- 4 files changed, 50 insertions(+), 26 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index a9a29ed90d..d1fee92976 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -583,7 +583,9 @@ class TestJitFree(unittest.TestCase): pre_free = GlobalCounters.mem_used fxn.captured.free_intermediates() savings_after_free = pre_free - GlobalCounters.mem_used - self.assertEqual(savings_after_free, 2024) + + # Different allocator implementations have different savings. + self.assertEqual(savings_after_free, 8196 if hasattr(Device[Device.DEFAULT].allocator, '_offset') else 2024) out = fxn(Tensor([11,1,2,3,4])) self.assertEqual(out.item(), 13600) diff --git a/tinygrad/device.py b/tinygrad/device.py index 4399ddbb5c..ce3287c21b 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -94,7 +94,7 @@ class Buffer: lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False): if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be? else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType) - self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset + self.device, self.size, self.dtype, self.options, self.offset, self.allocated_views = device, size, dtype, options, offset, 0 if base is None: assert offset == 0, "base buffers can't have offset" self._base = None @@ -122,6 +122,7 @@ class Buffer: self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr) if self._base is not None: self._base.ensure_allocated() + self._base.allocated_views += 1 assert hasattr(self.allocator, "_offset"), "offset function required for view" self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset) else: @@ -133,6 +134,7 @@ class Buffer: if self._base is None and (self.options is None or self.options.external_ptr is None): if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes self.allocator.free(self._buf, self.nbytes, self.options) + elif self._base is not None: self._base.allocated_views -= 1 del self._buf def __reduce__(self): buf = None diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 1400d8d772..125a549ccf 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -166,7 +166,9 @@ class CapturedJit(Generic[ReturnType]): depends: set[Buffer|None] = set([None]) update_depends(depends, self.jit_cache) for b in depends: - if b is not None: b.deallocate() + if b is not None: + b.deallocate() + if b._base is not None and b._base.allocated_views == 0: b._base.deallocate() self.__post_init__() # reset the graph state # jit exec diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index a9359ce259..d53ae565ea 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -1,46 +1,64 @@ +from typing import cast from collections import defaultdict from tinygrad.engine.schedule import ScheduleItem from tinygrad.device import Device, Buffer -from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG +from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up from tinygrad.ops import Ops +from tinygrad.dtype import dtypes, ImageDType +from tinygrad.runtime.support.allocator import TLSFAllocator # **************** memory planning **************** def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers=None, debug_prefix="") -> dict[Buffer, Buffer]: if NO_MEMORY_PLANNER: return {} - first_appearance, last_appearance = {}, {} + first_appearance, last_appearance, buf_to_opt = {}, {}, set() for i,u in enumerate(buffers): for buf in u: - if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue + if buf.is_allocated() or buf.base.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue if buf.base not in first_appearance: first_appearance[buf.base] = i last_appearance[buf.base] = i + buf_to_opt.add(buf) - # Sort buffers by size in descending order, prioritizing largest buffers for allocation first. - # Track free segments, each containing (start, stop, and buffer that could be reused on this segment). - free_segs: dict[tuple, list[tuple[int, int, Buffer]]] = defaultdict(list) # dict[buffer key, tuple[start, end, buffer to reuse on the seg]] - def find_replace_buffer(buf, st, en): - key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple()) + # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed. + buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \ + [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0]) - default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found. - seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf) + # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator. + # Also track buffer replacements for buffers that do not support suballocation. + buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {} + reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list) + global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(1 << 44, block_size=0x1000, lv2_cnt=32))) + for (_, is_open_ev), buf in buffer_requests: + # Check if suballocation is possible for the given buffer and device. + if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType): + if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000))) + else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1])) + global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1]) + else: + key = (buf.device, buf.dtype, buf.options, buf.nbytes) + if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None) + else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0])) - free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else [] - free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else [] + # Allocate global buffers based on the memory planner. + global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()} + buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()} - return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf) + # Assign buffers. First, assign full buffers (not sub-buffers). + assigned:dict[Buffer, Buffer] = {} + for buf, (base, off) in buffer_resolve.items(): + if buf != base: + assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off) - buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes) - assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests} + # Now assign sub-buffers. + for buf in buf_to_opt: + if buf._base is not None: + assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset) - for i,u in enumerate(buffers): - for buf in u: - if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue - if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset) - else: assigned[buf] = assigned.get(buf, buf) + if DEBUG >= 1: + ak, av = dedup(x for x in assigned.keys() if x._base is None),dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values()) + omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6 + if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs") - if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)): - print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,", - f"{len(ak)} -> {len(av)} bufs") return assigned def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]: