new memory scheduler (#5278)

* new memory schedule algo

* works

* fix

* fix

* linter

* tiny fixes

* do not optimize copy buffers

* more comments

* tiny cleanups
nimlgen
2024-07-04 18:06:04 +03:00
committed by GitHub
parent 84b3e3bb6f
commit 2778b6046c
5 changed files with 36 additions and 27 deletions

View File

@@ -326,4 +326,4 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
   def offset(self, buf, size:int, offset:int) -> HCQCompatAllocRes:
     return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
-                     **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']})
+                     **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
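For context: with `_base=buf` the offset view now carries a reference to the allocation it was carved from, so downstream code (device mapping, the memory planner) can always reach the real backing memory. A minimal standalone sketch of the idea, using a made-up `Alloc` dataclass rather than tinygrad's HCQ allocation types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Alloc:
  va_addr: int
  size: int
  _base: Optional["Alloc"] = None  # parent allocation for offset views

def offset_view(buf: Alloc, size: int, offset: int) -> Alloc:
  # the view shares memory with its parent; keep a reference to the base so
  # later code (mapping, reuse) can always reach the real allocation
  return Alloc(va_addr=buf.va_addr + offset, size=size, _base=buf)

base = Alloc(va_addr=0x1000, size=4096)
view = offset_view(base, size=1024, offset=2048)
assert view._base is base and view.va_addr == 0x1800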

View File

@@ -115,7 +115,7 @@ class TinyJit(Generic[ReturnType]):
       if found:=self.buffer_replace.get(b, None): return found
       if b.is_allocated() or b.lb_refcount > 0: return b
       if b._base is not None:
-        self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
+        self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
       else:
         self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
       return ret
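The replacement path now recurses through `add_buffer` for the view's base, so a base buffer and every view of it end up pointing at the same replacement. A rough standalone illustration of that recursion, with a toy `Buf` class standing in for tinygrad's `Buffer`:

class Buf:
  def __init__(self, size, base=None, offset=0):
    self.size, self.base, self.offset = size, base, offset

buffer_replace = {}
def add_buffer(b):
  # reuse the mapping if this buffer was already replaced
  if (found := buffer_replace.get(b)) is not None: return found
  if b.base is not None:
    # replace the base first, then rebuild the view on top of the replacement
    buffer_replace[b] = ret = Buf(b.size, base=add_buffer(b.base), offset=b.offset)
  else:
    buffer_replace[b] = ret = Buf(b.size)
  return ret

base = Buf(4096); view = Buf(1024, base=base, offset=1024)
assert add_buffer(view).base is add_buffer(base)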
@@ -173,7 +173,9 @@ class TinyJit(Generic[ReturnType]):
         self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
       # memory planning (optional)
-      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], debug_prefix="JIT ")
+      # Exclude buffers involved in transfer ops to preserve parallelism.
+      noopt_buffers = {b for ji in self.jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
+      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], noopt_buffers, debug_prefix="JIT ")
       self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]
       # Condense the items into a graph executor.
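The point of `noopt_buffers` is that buffers touched by `BufferXfer` items keep their own storage; if the planner aliased them with compute buffers, the copies would have to serialize behind unrelated kernels. A small sketch of how the exclusion set is built, with stand-in classes instead of the real JIT types:

class ExecItem:
  def __init__(self, prg, bufs): self.prg, self.bufs = prg, bufs

class BufferXfer: pass  # stand-in for tinygrad's device-to-device copy "program"
class Kernel: pass      # stand-in for an ordinary compute program

jit_cache = [ExecItem(Kernel(), ["a", "b"]), ExecItem(BufferXfer(), ["b", "c"])]
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
assert noopt_buffers == {"b", "c"}  # these keep their own storage, so the copy can overlap compute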

View File

@@ -6,10 +6,10 @@ from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, Mem
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
 from tinygrad.shape.symbolic import Variable
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.device import Buffer
+from tinygrad.device import Buffer, Device

 # creation can recurse a lot
 sys.setrecursionlimit(10000)
@@ -332,39 +332,44 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
 # *** memory planning ***
-def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
+def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
   if getenv("NO_MEMORY_PLANNER"): return {}
-  last_appearance = {}
+  first_appearance, last_appearance = {}, {}
   for i,u in enumerate(buffers):
-    for buf in u: last_appearance[buf] = i
+    for buf in u:
+      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      if buf.base not in first_appearance: first_appearance[buf.base] = i
+      last_appearance[buf.base] = i
-  # LRU algorithm
-  assigned: Dict[Buffer, Buffer] = {}
-  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
+  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
+  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
+  free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
+  def find_replace_buffer(buf, st, en):
+    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
-  def handle_buffer(buf):
-    key = (buf.device, buf.size, buf.dtype)
-    if buf not in assigned:
-      if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
-      else: assigned[buf] = Buffer(*key)
-    if i == last_appearance[buf]:
-      if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
+    default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found.
+    seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
+    free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
+    free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
+    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
+  buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
+  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
   for i,u in enumerate(buffers):
     for buf in u:
-      # all unallocated unparented buffers are fair game to replace
-      if buf.is_allocated() or buf.lb_refcount > 0: continue
-      # handle view buffers
-      if buf._base is not None:
-        assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
-      else:
-        handle_buffer(buf)
+      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
+      else: assigned[buf] = assigned.get(buf, buf)
-  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
+  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
     print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
           f"{len(ak)} -> {len(av)} bufs")
   return assigned

 def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
-  assigned = _internal_memory_planner([si.bufs for si in schedule])
+  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
+  assigned = _internal_memory_planner([si.bufs for si in schedule], noopt_buffers={b for si in schedule if si.ast[0].op in LoadOps for b in si.bufs})
   return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]
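In short, the new planner records a live interval (first to last appearance) per base buffer, serves requests largest-first, and reuses an existing buffer whenever a free segment of the schedule timeline covers the whole interval, splitting the leftover pieces back into the pool; only when nothing fits does it fall back to a fresh allocation. A simplified, self-contained sketch of that interval/segment scheme (toy integer ids instead of tinygrad Buffers, no device/dtype keys):

from collections import defaultdict

# Toy version of the segment planner: a request is (first_use, last_use, size).
# Requests are served largest-first; each one tries to grab a free segment of the
# timeline that fully covers its live interval, returning the unused ends to the pool.
def plan(requests, n_steps):
  free_segs = defaultdict(list)   # size class -> [(seg_start, seg_end, backing buffer id)]
  backing, next_id = {}, 0
  for first, last, size in sorted(requests, key=lambda r: -r[2]):
    segs = free_segs[size]
    pick = next((i for i,(s,e,_) in enumerate(segs) if s <= first and last <= e), None)
    if pick is None:
      seg_s, seg_e, buf = 0, n_steps - 1, next_id  # nothing fits: back it with a fresh buffer
      next_id += 1
    else:
      seg_s, seg_e, buf = segs.pop(pick)
    if first - 1 >= seg_s: segs.append((seg_s, first - 1, buf))  # free part before the interval
    if seg_e >= last + 1: segs.append((last + 1, seg_e, buf))    # free part after the interval
    backing[(first, last, size)] = buf
  return backing, next_id

# three same-sized buffers; the first two never overlap in time, so they share one backing buffer
reqs = [(0, 1, 16), (2, 3, 16), (1, 2, 16)]
assign, n_alloc = plan(reqs, n_steps=4)
assert n_alloc == 2 and assign[(0, 1, 16)] == assign[(2, 3, 16)]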

View File

@@ -412,6 +412,7 @@ class AMDDevice(HCQCompatCompiled):
   gpus:List[pathlib.Path] = []

   def _gpu_map(self, mem):
+    mem = mem._base if hasattr(mem, '_base') else mem
     if self.gpu_id in getattr(mem, "mapped_gpu_ids", []): return
     mem.__setattr__("mapped_gpu_ids", getattr(mem, "mapped_gpu_ids", []) + [self.gpu_id])
     c_gpus = (ctypes.c_int32 * len(mem.mapped_gpu_ids))(*mem.mapped_gpu_ids)

View File

@@ -462,6 +462,7 @@ class NVDevice(HCQCompatCompiled):
                          gpuAttributesCount=1, perGpuAttributes=gpu_attrs, va_addr=va_base, size=size)

   def _gpu_map(self, mem):
+    mem = mem._base if hasattr(mem, '_base') else mem
     if self.gpu_uuid in getattr(mem, "mapped_gpu_ids", []): return
     mem.__setattr__("mapped_gpu_ids", getattr(mem, "mapped_gpu_ids", []) + [self.gpu_uuid])
     return self._gpu_uvm_map(mem.va_addr, mem.size, mem.hMemory, create_range=False)
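Both drivers now resolve a view to its base allocation before mapping, since the planner can hand several views the same backing buffer and the mapping must happen exactly once on the real allocation. A driver-agnostic sketch of that resolve-then-map-once pattern (hypothetical Mem/gpu_map names, not the real AMD/NV API):

class Mem:
  def __init__(self, base=None):
    self._base = base          # set when this is an offset view produced by the planner
    self.mapped_gpu_ids = []

def gpu_map(mem, gpu_id):
  # always operate on the backing allocation, never the view
  mem = mem._base if mem._base is not None else mem
  if gpu_id in mem.mapped_gpu_ids: return   # already mapped on this GPU, nothing to do
  mem.mapped_gpu_ids.append(gpu_id)
  # ... a real driver would issue the map call here ...

base = Mem(); view = Mem(base=base)
gpu_map(view, 0); gpu_map(base, 0)
assert base.mapped_gpu_ids == [0]  # mapped exactly once, whether reached via the view or the base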