mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
new memory scheduler (#5278)
* new memory schedule algo
* works
* fix
* fix
* linter
* tiny fixes
* do not optimize copy buffers
* more comments
* tiny cleanups
@@ -326,4 +326,4 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
   def offset(self, buf, size:int, offset:int) -> HCQCompatAllocRes:
     return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
-                     **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']})
+                     **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
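
The added _base=buf keyword lets the sub-buffer returned by offset() remember which real allocation it was carved from, so later code can always reach the backing memory. A minimal standalone sketch of that pattern, using hypothetical Alloc/make_view names rather than tinygrad's types:

# Hypothetical sketch: a view keeps a reference to the allocation it was carved from.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Alloc:
  va_addr: int
  size: int
  base: Optional["Alloc"] = None   # None for real allocations, set for views

def make_view(buf: Alloc, size: int, offset: int) -> Alloc:
  # Same metadata as the original, shifted address, plus a pointer back to the base.
  return Alloc(va_addr=buf.va_addr + offset, size=size, base=buf)

real = Alloc(va_addr=0x1000, size=4096)
view = make_view(real, size=1024, offset=2048)
assert view.va_addr == 0x1800 and view.base is real
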
@@ -115,7 +115,7 @@ class TinyJit(Generic[ReturnType]):
     if found:=self.buffer_replace.get(b, None): return found
     if b.is_allocated() or b.lb_refcount > 0: return b
     if b._base is not None:
-      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
+      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
     else:
       self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
     return ret
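
Switching from self.buffer_replace.get(b._base, b._base) to self.add_buffer(b._base) sends the base through the same registration path as the view, so a base that is first seen through one of its views still ends up in buffer_replace. A hedged sketch of that recursion, with made-up Buf/Registry names standing in for tinygrad's Buffer and TinyJit:

# Hypothetical sketch of recursively registering a view's base before the view itself.
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass(eq=False)          # identity-based hashing so Buf can be a dict key
class Buf:
  size: int
  base: Optional["Buf"] = None
  offset: int = 0

class Registry:
  def __init__(self): self.replace: Dict[Buf, Buf] = {}
  def add(self, b: Buf) -> Buf:
    if (found := self.replace.get(b)) is not None: return found
    if b.base is not None:
      # Register the base first (recursively), then rebuild the view on top of its replacement.
      new_base = self.add(b.base)
      self.replace[b] = ret = Buf(b.size, base=new_base, offset=b.offset)
    else:
      self.replace[b] = ret = Buf(b.size)
    return ret

base = Buf(4096)
view = Buf(1024, base=base, offset=128)
reg = Registry()
new_view = reg.add(view)                      # the base is registered as a side effect
assert reg.replace[base] is new_view.base
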
@@ -173,7 +173,9 @@ class TinyJit(Generic[ReturnType]):
         self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
 
     # memory planning (optional)
-    assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], debug_prefix="JIT ")
+    # Exclude buffers involved in transfer ops to preserve parallelism.
+    noopt_buffers = {b for ji in self.jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
+    assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], noopt_buffers, debug_prefix="JIT ")
     self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]
 
     # Condense the items into a graph executor.
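
The planner is told to leave any buffer touched by a transfer (BufferXfer) alone, since aliasing those buffers would serialize copies that could otherwise run in parallel. The filter itself is a set comprehension over the cached execution items; a small sketch with stand-in types (Kernel, Xfer, ExecItem here are illustrative, not the tinygrad classes):

# Hypothetical sketch of collecting "do not optimize" buffers from cached execution items.
from dataclasses import dataclass
from typing import List, Set

@dataclass(eq=False)
class Buf:
  size: int

class Kernel: pass          # stand-in for a compute program
class Xfer: pass            # stand-in for a copy/transfer program (BufferXfer in this PR)

@dataclass
class ExecItem:
  prg: object
  bufs: List[Buf]

def collect_noopt(jit_cache: List[ExecItem]) -> Set[Buf]:
  # Every buffer touched by a transfer keeps its own storage so copies stay independent.
  return {b for ji in jit_cache if isinstance(ji.prg, Xfer) for b in ji.bufs}

a, b, c = Buf(64), Buf(64), Buf(64)
cache = [ExecItem(Kernel(), [a, b]), ExecItem(Xfer(), [b, c])]
assert collect_noopt(cache) == {b, c}
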
@@ -6,10 +6,10 @@ from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, Mem
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
 from tinygrad.shape.symbolic import Variable
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.device import Buffer
+from tinygrad.device import Buffer, Device
 
 # creation can recurse a lot
 sys.setrecursionlimit(10000)
@@ -332,39 +332,44 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
 
 # *** memory planning ***
 
-def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
+def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
   if getenv("NO_MEMORY_PLANNER"): return {}
-  last_appearance = {}
+  first_appearance, last_appearance = {}, {}
   for i,u in enumerate(buffers):
-    for buf in u: last_appearance[buf] = i
+    for buf in u:
+      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      if buf.base not in first_appearance: first_appearance[buf.base] = i
+      last_appearance[buf.base] = i
 
-  # LRU algorithm
-  assigned: Dict[Buffer, Buffer] = {}
-  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
-
-  def handle_buffer(buf):
-    key = (buf.device, buf.size, buf.dtype)
-    if buf not in assigned:
-      if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
-      else: assigned[buf] = Buffer(*key)
-    if i == last_appearance[buf]:
-      if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
+  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
+  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
+  free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
+  def find_replace_buffer(buf, st, en):
+    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
+
+    default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found.
+    seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
+
+    free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
+    free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
+
+    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
+
+  buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
+  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
 
   for i,u in enumerate(buffers):
     for buf in u:
-      # all unallocated unparented buffers are fair game to replace
-      if buf.is_allocated() or buf.lb_refcount > 0: continue
-      # handle view buffers
-      if buf._base is not None:
-        assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
-      else:
-        handle_buffer(buf)
+      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
+      else: assigned[buf] = assigned.get(buf, buf)
 
-  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
+  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
     print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
           f"{len(ak)} -> {len(av)} bufs")
   return assigned
 
 def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
-  assigned = _internal_memory_planner([si.bufs for si in schedule])
+  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
+  assigned = _internal_memory_planner([si.bufs for si in schedule], noopt_buffers={b for si in schedule if si.ast[0].op in LoadOps for b in si.bufs})
   return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]
 
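
For intuition, here is a minimal, self-contained sketch of the same idea as find_replace_buffer above: each buffer is reduced to a (first use, last use, size) lifetime, requests are served largest-first, and each backing allocation keeps a list of free step ranges that later, smaller requests can be carved out of. All names here (Lifetime, Alloc, plan) are illustrative, not tinygrad API, and details like the per-device/dtype key and sub-buffer views are omitted.

# Hypothetical illustration of the segment-based reuse idea from this PR (not tinygrad API).
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class Lifetime:
  name: str      # buffer identifier
  first: int     # first schedule step that touches the buffer
  last: int      # last schedule step that touches the buffer
  nbytes: int    # requested size

@dataclass
class Alloc:
  nbytes: int
  users: List[str] = field(default_factory=list)

def plan(lifetimes: List[Lifetime], num_steps: int) -> Dict[str, Alloc]:
  # Free segments per backing allocation: (start_step, end_step) ranges where it is unused.
  free_segs: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
  allocs: List[Alloc] = []
  assigned: Dict[str, Alloc] = {}
  # Serve the largest requests first so big buffers claim backing storage early.
  for lt in sorted(lifetimes, key=lambda l: -l.nbytes):
    chosen = None
    for ai, segs in free_segs.items():
      if allocs[ai].nbytes < lt.nbytes: continue
      for si, (st, en) in enumerate(segs):
        if st <= lt.first and lt.last <= en:   # the whole lifetime fits in a free segment
          segs.pop(si)
          # Split the segment around the lifetime; keep the leftover pieces free.
          if st <= lt.first - 1: segs.append((st, lt.first - 1))
          if lt.last + 1 <= en: segs.append((lt.last + 1, en))
          chosen = ai
          break
      if chosen is not None: break
    if chosen is None:
      # No reusable segment: create a fresh backing allocation, free outside the lifetime.
      allocs.append(Alloc(lt.nbytes))
      chosen = len(allocs) - 1
      if 0 <= lt.first - 1: free_segs[chosen].append((0, lt.first - 1))
      if lt.last + 1 <= num_steps - 1: free_segs[chosen].append((lt.last + 1, num_steps - 1))
    allocs[chosen].users.append(lt.name)
    assigned[lt.name] = allocs[chosen]
  return assigned

if __name__ == "__main__":
  # Three buffers over 4 steps: "a" and "c" never overlap, so they can share one allocation.
  lts = [Lifetime("a", 0, 1, 1024), Lifetime("b", 0, 3, 512), Lifetime("c", 2, 3, 1024)]
  out = plan(lts, num_steps=4)
  print(out["a"] is out["c"])  # True: "c" reuses the allocation of "a"

Compared to the old LRU scheme, which only recycled a buffer after its last appearance and only for exact (device, size, dtype) matches, planning over whole lifetimes lets a large allocation host several differently-sized buffers whose live ranges do not overlap.
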
@@ -412,6 +412,7 @@ class AMDDevice(HCQCompatCompiled):
   gpus:List[pathlib.Path] = []
 
   def _gpu_map(self, mem):
+    mem = mem._base if hasattr(mem, '_base') else mem
     if self.gpu_id in getattr(mem, "mapped_gpu_ids", []): return
     mem.__setattr__("mapped_gpu_ids", getattr(mem, "mapped_gpu_ids", []) + [self.gpu_id])
     c_gpus = (ctypes.c_int32 * len(mem.mapped_gpu_ids))(*mem.mapped_gpu_ids)
@@ -462,6 +462,7 @@ class NVDevice(HCQCompatCompiled):
                                      gpuAttributesCount=1, perGpuAttributes=gpu_attrs, va_addr=va_base, size=size)
 
   def _gpu_map(self, mem):
+    mem = mem._base if hasattr(mem, '_base') else mem
     if self.gpu_uuid in getattr(mem, "mapped_gpu_ids", []): return
     mem.__setattr__("mapped_gpu_ids", getattr(mem, "mapped_gpu_ids", []) + [self.gpu_uuid])
     return self._gpu_uvm_map(mem.va_addr, mem.size, mem.hMemory, create_range=False)
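
Both drivers now unwrap a view to its base before mapping, so the mapped-GPU bookkeeping lives on the real allocation and mapping a view never maps the same memory twice. A small sketch of that guard (Mem and gpu_map are hypothetical; the check is simplified to a None test instead of hasattr):

# Hypothetical sketch: map the base allocation once, no matter how many views reference it.
class Mem:
  def __init__(self, base=None):
    self._base = base
    self.mapped_gpu_ids = []

def gpu_map(mem: Mem, gpu_id: int, mapped_log: list):
  mem = mem._base if mem._base is not None else mem   # always operate on the real allocation
  if gpu_id in mem.mapped_gpu_ids: return             # already mapped on this GPU
  mem.mapped_gpu_ids.append(gpu_id)
  mapped_log.append((id(mem), gpu_id))                # stand-in for the actual driver mapping call

log = []
base = Mem()
view = Mem(base=base)
gpu_map(base, gpu_id=0, mapped_log=log)
gpu_map(view, gpu_id=0, mapped_log=log)   # no second mapping: the view resolves to its base
assert len(log) == 1
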