new memory scheduler with explicit refcounts (#4198)

* new memory scheduler with explicit refcounts

* move central memory planner

* typo + use central memory planner in openpilot

* cleanups

* include lb_refcount in pickle

* replace PlaceHolder with memory planner

* cleaner
Author: George Hotz
Date: 2024-04-17 08:46:47 +04:00
Committed by: GitHub
Parent: c91b7b1739
Commit: 8564e28a1b
6 changed files with 67 additions and 46 deletions

examples/openpilot/compile2.py

@@ -14,7 +14,7 @@ from extra.onnx import get_run_onnx
 from tinygrad import Tensor, Device, GlobalCounters, dtypes
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.ops import LoadOps, ScheduleItem
 Device.DEFAULT = "GPU"
@@ -107,6 +107,7 @@ if __name__ == "__main__":
   run_schedule(schedule_independent)
   run_schedule(schedule_input)
+  schedule = memory_planner(schedule)
   with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
     image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
     print(f"**** running real kernels {image_count}/{len(schedule)} images ****")

tinygrad/buffer.py

@@ -12,14 +12,16 @@ class BufferOptions:
   nolru: bool = False

 class Buffer:
-  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, initial_value:Optional[bytes]=None):
+  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
+               initial_value:Optional[bytes]=None, lb_refcount=0):
     assert isinstance(dtype, DType)
     if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
-    self.device, self.size, self.dtype, self.options = device, size, dtype, options
+    self.device, self.size, self.dtype, self.options, self.lb_refcount = device, size, dtype, options, lb_refcount
     if opaque is not None: self.allocate(opaque)
     if initial_value is not None:
       self.allocate()
       self.copyin(memoryview(initial_value))
+  def is_allocated(self) -> bool: return hasattr(self, '_buf')
   def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
   def allocate(self, opaque=None) -> Buffer:
     assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
@@ -30,11 +32,11 @@ class Buffer:
     return self
   def __reduce__(self):
     buf = None
-    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options)
-    if hasattr(self, '_buf'):
+    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
+    if self.is_allocated():
       buf = bytearray(self.nbytes)
       self.copyout(memoryview(buf))
-    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)
+    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
   @property
   def nbytes(self): return self.size*self.dtype.itemsize
   def __del__(self):
@@ -52,10 +54,12 @@ class Buffer:
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+    assert self.is_allocated(), "can't copyin to unallocated buffer"
     self.allocator.copyin(self._buf, mv)
     return self
   def copyout(self, mv:memoryview) -> memoryview:
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+    assert self.is_allocated(), "can't copyout unallocated buffer"
     self.allocator.copyout(mv, self._buf)
     return mv
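Since __reduce__ now forwards lb_refcount, the refcount survives pickling along with the buffer contents. A quick sketch of the round trip (the device name is a placeholder; any CPU-backed device should do):

import pickle
from tinygrad import dtypes
from tinygrad.buffer import Buffer

buf = Buffer("CLANG", 4, dtypes.uint8, initial_value=bytes(4))
buf.lb_refcount += 1                    # pretend a live LazyBuffer owns it
buf2 = pickle.loads(pickle.dumps(buf))  # data and refcount both round-trip
assert buf2.is_allocated() and buf2.lb_refcount == 1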

tinygrad/engine/jit.py

@@ -1,17 +1,16 @@
 from __future__ import annotations
 from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
 import functools, itertools, operator
-from dataclasses import dataclass
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException
-from tinygrad.device import Buffer, Runner, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
+from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.engine.realize import ExecItem, capturing
+from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner
 from tinygrad.nn.state import get_parameters
-from weakref import ref, WeakKeyDictionary
+from weakref import WeakKeyDictionary

 # TODO: these graph functions probably shouldn't exist here
@@ -71,44 +70,25 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
       input_replace[(j,i)] = input_rawbuffers.index(a)
   return input_replace

-class PlaceHolder:
-  placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
-  def __init__(self, buf:Buffer):
-    self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options
-  def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options)
-  def __hash__(self): return hash(self.to_tuple())
-  def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
-  @staticmethod
-  def create_if_needed(buf:Buffer) -> Union[PlaceHolder, Buffer]:
-    if found:=PlaceHolder.placeholders.get(buf, None): return found
-    if hasattr(buf, '_buf'): return buf
-    PlaceHolder.placeholders[buf] = ret = PlaceHolder(buf.ensure_allocated()) # TODO: do I need to allocate here?
-    return ret
-  def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
-    ret = self.ref()
-    if ret: return ret
-    if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
-    return buffer_cache[self]
-
-@dataclass(frozen=True)
-class WeakExecItem:
-  prg: Runner
-  rawbufs: List[Union[PlaceHolder, Buffer]]
-
 ReturnType = TypeVar('ReturnType')
 class TinyJit(Generic[ReturnType]):
   def __init__(self, fxn:Callable[..., ReturnType]):
     self.fxn = fxn
     self.reset()
+  def add_buffer(self, b:Buffer) -> Buffer:
+    if found:=self.buffer_replace.get(b, None): return found
+    if b.is_allocated() or b.lb_refcount > 0: return b
+    self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
+    return ret
   def add(self, ei:ExecItem):
-    self._cc.append(WeakExecItem(ei.prg, [PlaceHolder.create_if_needed(buf) for buf in ei.rawbufs if buf is not None]))
+    self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.rawbufs if buf is not None]))
   def reset(self):
-    self._cc: List[WeakExecItem] = []
     self.jit_cache: List[ExecItem] = []
     self.input_replace: Dict[Tuple[int, int], int] = {}
+    self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
     self.cnt: int = 0
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
@@ -140,13 +120,14 @@ class TinyJit(Generic[ReturnType]):
         self.ret = self.fxn(*args, **kwargs)
         Tensor.corealize(get_parameters(self.ret))
         capturing.clear()
-      assert len(self._cc), "didn't JIT anything!"
-      buffer_cache: Dict[PlaceHolder, Buffer] = {}
-      self.jit_cache = \
-        [ExecItem(ei.prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in ei.rawbufs]) for ei in self._cc]
-      del self._cc
+      del self.buffer_replace
+      assert len(self.jit_cache), "didn't JIT anything!"
       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
+
+      # memory planning (optional)
+      assigned = _internal_memory_planner([cast(List[Buffer], x.rawbufs) for x in self.jit_cache], debug_prefix="JIT ")
+      self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.rawbufs if x is not None]) for ei in self.jit_cache]
+
       # Condense the items into a graph executor.
       if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)

tinygrad/engine/realize.py

@@ -1,6 +1,8 @@
-from typing import List, Dict, Optional, cast, Generator
+from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Iterable
+from collections import defaultdict
 from dataclasses import dataclass
-from tinygrad.helpers import colored, getenv
+from tinygrad.dtype import DType
+from tinygrad.helpers import colored, getenv, dedup, DEBUG
 from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast
 from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
 from tinygrad.buffer import Buffer
@@ -41,6 +43,34 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
 capturing: List = [] # put classes with an add method in here

+def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") -> Dict[Buffer, Buffer]:
+  last_appearance = {}
+  for i,u in enumerate(buffers):
+    for buf in u: last_appearance[buf] = i
+
+  # LRU algorithm
+  assigned: Dict[Buffer, Buffer] = {}
+  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
+  for i,u in enumerate(buffers):
+    for buf in u:
+      # all unallocated unparented buffers are fair game to replace
+      if buf.is_allocated() or buf.lb_refcount > 0: continue
+      key = (buf.device, buf.size, buf.dtype)
+      if buf not in assigned:
+        if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
+        else: assigned[buf] = Buffer(*key)
+      if i == last_appearance[buf]:
+        local_cache[key].append(assigned[buf])
+
+  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
+    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB to {sum([x.nbytes for x in av])/1e6:.2f} MB")
+  return assigned
+
+def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
+  assigned = _internal_memory_planner([si.outputs+si.inputs for si in schedule])
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.outputs),
+                       tuple(assigned.get(x, x) for x in si.inputs)) for si in schedule]
+
 def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
   for ei in lower_schedule(schedule):
     if len(capturing): capturing[0].add(ei)
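_internal_memory_planner is a linear-scan allocator in miniature: record each buffer's last appearance, then walk the schedule keeping a free list per (device, size, dtype) key; a buffer pops a freed slot at its first appearance and returns its slot at its last, while already-allocated or lb_refcount-held buffers are never touched. A toy rerun of the same idea on string names (all data here is made up):

# kernel i lists its buffers, outputs first:
# k0 writes A; k1 reads A, writes B; k2 reads B, writes C
usage = [["A"], ["B", "A"], ["C", "B"]]
last = {n: i for i, u in enumerate(usage) for n in u}  # last kernel each name appears in

free, assigned, slots = [], {}, 0
for i, u in enumerate(usage):
  for n in u:
    if n not in assigned:  # first appearance: pop a freed slot or open a new one
      assigned[n] = free.pop() if free else (slots := slots + 1)
    if i == last[n]: free.append(assigned[n])  # last appearance: slot becomes reusable
print(assigned)  # {'A': 1, 'B': 2, 'C': 1} -> three logical buffers in two slots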

tinygrad/lazy.py

@@ -33,6 +33,7 @@ class LazyBuffer:
       self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
       assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
       self.buffer: Buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+      self.buffer.lb_refcount += 1
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
       self.forced_realize = False
     else:
@@ -40,6 +41,9 @@ class LazyBuffer:
     assert base.base == base, "base must be a base itself"
     self._base = base

+  def __del__(self):
+    if hasattr(self, 'buffer'): self.buffer.lb_refcount -= 1
+
   def __repr__(self) -> str:
     return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"

tinygrad/tensor.py

@@ -15,7 +15,7 @@ from tinygrad.ops import LoadOps
 from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
 from tinygrad.shape.symbolic import sint
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule_with_vars

 # **** start with two base classes, Tensor and Function ****
@@ -145,7 +145,8 @@ class Tensor:
     if getenv("FUZZ_SCHEDULE"):
       from test.external.fuzz_schedule import fuzz_schedule
       fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
-    run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst])))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
+    run_schedule(memory_planner(schedule), var_vals)

   def realize(self) -> Tensor:
     """Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""