new memory scheduler with explicit refcounts (#4198)
* new memory scheduler with explicit refcounts
* move central memory planner
* typo + use central memory planner in openpilot
* cleanups
* include lb_refcount in pickle
* replace PlaceHolder with memory planner
* cleaner
examples/openpilot/compile2.py:

```diff
@@ -14,7 +14,7 @@ from extra.onnx import get_run_onnx
 from tinygrad import Tensor, Device, GlobalCounters, dtypes
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.ops import LoadOps, ScheduleItem
 Device.DEFAULT = "GPU"
```
```diff
@@ -107,6 +107,7 @@ if __name__ == "__main__":
 
   run_schedule(schedule_independent)
   run_schedule(schedule_input)
+  schedule = memory_planner(schedule)
   with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
     image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
     print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
```
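This hunk wires the planner into the openpilot compile flow: the finished schedule is rewritten once by memory_planner before the real kernels run. A minimal sketch of the same pattern, assuming this commit's APIs and a toy graph standing in for the ONNX model:

```python
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule, memory_planner

x = Tensor.rand(64, 64)
out = (x @ x).relu().sum()                  # toy stand-in for the ONNX model output
schedule = create_schedule([out.lazydata])  # lower the lazy graph to ScheduleItems
schedule = memory_planner(schedule)         # remap dead intermediates onto shared Buffers
run_schedule(schedule)                      # execute the planned schedule
```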
tinygrad/buffer.py:

```diff
@@ -12,14 +12,16 @@ class BufferOptions:
   nolru: bool = False
 
 class Buffer:
-  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, initial_value:Optional[bytes]=None):
+  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
+               initial_value:Optional[bytes]=None, lb_refcount=0):
     assert isinstance(dtype, DType)
     if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
-    self.device, self.size, self.dtype, self.options = device, size, dtype, options
+    self.device, self.size, self.dtype, self.options, self.lb_refcount = device, size, dtype, options, lb_refcount
     if opaque is not None: self.allocate(opaque)
     if initial_value is not None:
       self.allocate()
       self.copyin(memoryview(initial_value))
+  def is_allocated(self) -> bool: return hasattr(self, '_buf')
   def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
   def allocate(self, opaque=None) -> Buffer:
     assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
```
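The new is_allocated helper makes Buffer's lazy-allocation contract explicit: constructing a Buffer no longer implies device memory, and the new lb_refcount starts at zero until a LazyBuffer claims it. A minimal sketch, assuming the CLANG backend:

```python
from tinygrad import dtypes
from tinygrad.buffer import Buffer

b = Buffer("CLANG", 16, dtypes.float32)  # no opaque or initial_value: stays unallocated
assert not b.is_allocated() and b.lb_refcount == 0
b.allocate()                             # returns self; now backed by device memory
assert b.is_allocated()
```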
```diff
@@ -30,11 +32,11 @@ class Buffer:
     return self
   def __reduce__(self):
     buf = None
-    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options)
+    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
-    if hasattr(self, '_buf'):
+    if self.is_allocated():
       buf = bytearray(self.nbytes)
       self.copyout(memoryview(buf))
-    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)
+    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
   @property
   def nbytes(self): return self.size*self.dtype.itemsize
   def __del__(self):
```
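Forwarding lb_refcount through __reduce__ means a pickled Buffer now round-trips its LazyBuffer reference count along with its contents, where previously the count would have been lost. A sketch, again assuming the CLANG backend:

```python
import pickle
from tinygrad import dtypes
from tinygrad.buffer import Buffer

src = Buffer("CLANG", 4, dtypes.float32, initial_value=bytes(16), lb_refcount=2)
dst = pickle.loads(pickle.dumps(src))    # __reduce__ snapshots contents + refcount
assert dst.is_allocated() and dst.lb_refcount == 2
```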
```diff
@@ -52,10 +54,12 @@ class Buffer:
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+    assert self.is_allocated(), "can't copyin to unallocated buffer"
     self.allocator.copyin(self._buf, mv)
     return self
   def copyout(self, mv:memoryview) -> memoryview:
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+    assert self.is_allocated(), "can't copyout unallocated buffer"
     self.allocator.copyout(mv, self._buf)
     return mv
```
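These asserts turn what used to be an opaque AttributeError on self._buf into a clear message when data moves through an unallocated Buffer:

```python
from tinygrad import dtypes
from tinygrad.buffer import Buffer

b = Buffer("CLANG", 4, dtypes.int32)     # 4 int32s = 16 bytes, never allocated
try:
  b.copyin(memoryview(bytearray(16)))    # passes the size check, hits the new assert
except AssertionError as e:
  print(e)                               # -> can't copyin to unallocated buffer
```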
tinygrad/engine/jit.py:

```diff
@@ -1,17 +1,16 @@
 from __future__ import annotations
 from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
 import functools, itertools, operator
-from dataclasses import dataclass
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException
-from tinygrad.device import Buffer, Runner, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
+from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.engine.realize import ExecItem, capturing
+from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner
 from tinygrad.nn.state import get_parameters
-from weakref import ref, WeakKeyDictionary
+from weakref import WeakKeyDictionary
 
 # TODO: these graph functions probably shouldn't exist here
 
```
```diff
@@ -71,44 +70,25 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
         input_replace[(j,i)] = input_rawbuffers.index(a)
   return input_replace
 
-class PlaceHolder:
-  placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
-  def __init__(self, buf:Buffer):
-    self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options
-  def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options)
-  def __hash__(self): return hash(self.to_tuple())
-  def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
-
-  @staticmethod
-  def create_if_needed(buf:Buffer) -> Union[PlaceHolder, Buffer]:
-    if found:=PlaceHolder.placeholders.get(buf, None): return found
-    if hasattr(buf, '_buf'): return buf
-    PlaceHolder.placeholders[buf] = ret = PlaceHolder(buf.ensure_allocated()) # TODO: do I need to allocate here?
-    return ret
-
-  def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
-    ret = self.ref()
-    if ret: return ret
-    if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
-    return buffer_cache[self]
-
-@dataclass(frozen=True)
-class WeakExecItem:
-  prg: Runner
-  rawbufs: List[Union[PlaceHolder, Buffer]]
-
 ReturnType = TypeVar('ReturnType')
 class TinyJit(Generic[ReturnType]):
   def __init__(self, fxn:Callable[..., ReturnType]):
     self.fxn = fxn
     self.reset()
 
+  def add_buffer(self, b:Buffer) -> Buffer:
+    if found:=self.buffer_replace.get(b, None): return found
+    if b.is_allocated() or b.lb_refcount > 0: return b
+    self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
+    return ret
+
   def add(self, ei:ExecItem):
-    self._cc.append(WeakExecItem(ei.prg, [PlaceHolder.create_if_needed(buf) for buf in ei.rawbufs if buf is not None]))
+    self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.rawbufs if buf is not None]))
 
   def reset(self):
-    self._cc: List[WeakExecItem] = []
     self.jit_cache: List[ExecItem] = []
     self.input_replace: Dict[Tuple[int, int], int] = {}
+    self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
     self.cnt: int = 0
 
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
```
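add_buffer is the capture-time filter that replaces PlaceHolder: allocated buffers (weights, inputs) and buffers still referenced by a live LazyBuffer pass through untouched, while anonymous dead intermediates are cloned so the JIT owns their storage. A sketch of the three cases, assuming the CLANG backend:

```python
from tinygrad import dtypes
from tinygrad.buffer import Buffer
from tinygrad.engine.jit import TinyJit

jit = TinyJit(lambda: None)                               # reset() sets up buffer_replace
weight = Buffer("CLANG", 4, dtypes.float32).allocate()    # allocated -> kept as-is
live = Buffer("CLANG", 4, dtypes.float32, lb_refcount=1)  # still referenced -> kept as-is
dead = Buffer("CLANG", 4, dtypes.float32)                 # dead intermediate -> cloned
assert jit.add_buffer(weight) is weight and jit.add_buffer(live) is live
assert jit.add_buffer(dead) is not dead
assert jit.add_buffer(dead) is jit.add_buffer(dead)       # same source maps to same clone
```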
```diff
@@ -140,13 +120,14 @@ class TinyJit(Generic[ReturnType]):
         self.ret = self.fxn(*args, **kwargs)
         Tensor.corealize(get_parameters(self.ret))
       capturing.clear()
-      assert len(self._cc), "didn't JIT anything!"
-      buffer_cache: Dict[PlaceHolder, Buffer] = {}
-      self.jit_cache = \
-        [ExecItem(ei.prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in ei.rawbufs]) for ei in self._cc]
-      del self._cc
+      del self.buffer_replace
+      assert len(self.jit_cache), "didn't JIT anything!"
       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
 
+      # memory planning (optional)
+      assigned = _internal_memory_planner([cast(List[Buffer], x.rawbufs) for x in self.jit_cache], debug_prefix="JIT ")
+      self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.rawbufs if x is not None]) for ei in self.jit_cache]
+
       # Condense the items into a graph executor.
       if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
 
```
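On the capture call, the captured jit_cache is handed to _internal_memory_planner before graph condensation, so dead intermediates collapse onto shared storage and are then force-allocated with ensure_allocated. An end-to-end sketch of the lifecycle (roughly: the first call runs eagerly, the second captures and plans, later calls replay), assuming TinyJit is exported from the tinygrad top level as it was in this era:

```python
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()

for _ in range(4):
  x = Tensor.rand(8, 8).realize()  # JIT inputs should be realized, same shape each call
  out = step(x)                    # with DEBUG=1, capture can print "JIT memory reduced ..."
```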
tinygrad/engine/realize.py:

```diff
@@ -1,6 +1,8 @@
-from typing import List, Dict, Optional, cast, Generator
+from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Iterable
+from collections import defaultdict
 from dataclasses import dataclass
-from tinygrad.helpers import colored, getenv
+from tinygrad.dtype import DType
+from tinygrad.helpers import colored, getenv, dedup, DEBUG
 from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast
 from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
 from tinygrad.buffer import Buffer
```
```diff
@@ -41,6 +43,34 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
 
 capturing: List = [] # put classes with an add method in here
 
+def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") -> Dict[Buffer, Buffer]:
+  last_appearance = {}
+  for i,u in enumerate(buffers):
+    for buf in u: last_appearance[buf] = i
+
+  # LRU algorithm
+  assigned: Dict[Buffer, Buffer] = {}
+  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
+  for i,u in enumerate(buffers):
+    for buf in u:
+      # all unallocated unparented buffers are fair game to replace
+      if buf.is_allocated() or buf.lb_refcount > 0: continue
+      key = (buf.device, buf.size, buf.dtype)
+      if buf not in assigned:
+        if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
+        else: assigned[buf] = Buffer(*key)
+      if i == last_appearance[buf]:
+        local_cache[key].append(assigned[buf])
+
+  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
+    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB to {sum([x.nbytes for x in av])/1e6:.2f} MB")
+  return assigned
+
+def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
+  assigned = _internal_memory_planner([si.outputs+si.inputs for si in schedule])
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.outputs),
+                       tuple(assigned.get(x, x) for x in si.inputs)) for si in schedule]
+
 def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
   for ei in lower_schedule(schedule):
     if len(capturing): capturing[0].add(ei)
```
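The planner is a greedy liveness pass: the first loop records each buffer's last appearance, and the second walks the schedule in order, handing every replaceable buffer a recycled Buffer from a free list keyed by (device, size, dtype) and returning that Buffer to the list once its last appearance passes. A toy demonstration, assuming this commit's Buffer API and the CLANG backend:

```python
from tinygrad import dtypes
from tinygrad.buffer import Buffer
from tinygrad.engine.realize import _internal_memory_planner

# three same-shaped dead intermediates whose lifetimes don't overlap
a, b, c = (Buffer("CLANG", 1024, dtypes.float32) for _ in range(3))
steps = [[a], [b], [c]]                            # step i = buffers kernel i touches
assigned = _internal_memory_planner(steps, debug_prefix="demo ")
assert assigned[a] is assigned[b] is assigned[c]   # all collapse onto one backing Buffer
```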
tinygrad/lazy.py:

```diff
@@ -33,6 +33,7 @@ class LazyBuffer:
       self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
       assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
       self.buffer: Buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+      self.buffer.lb_refcount += 1
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
       self.forced_realize = False
     else:
```
```diff
@@ -40,6 +41,9 @@ class LazyBuffer:
       assert base.base == base, "base must be a base itself"
       self._base = base
 
+  def __del__(self):
+    if hasattr(self, 'buffer'): self.buffer.lb_refcount -= 1
+
   def __repr__(self) -> str:
     return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"
 
```
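This pair is the refcount lifecycle: every base LazyBuffer takes a count on its Buffer at construction and releases it in __del__, which is how the planner and JIT distinguish a dead intermediate from data still reachable from user code. A sketch that relies on CPython's immediate refcount collection:

```python
from tinygrad import Tensor

t = Tensor.rand(4, 4).contiguous()
buf = t.lazydata.base.buffer
print(buf.lb_refcount)  # >= 1 while t's LazyBuffer graph is alive
del t                   # LazyBuffer.__del__ runs, dropping the count
print(buf.lb_refcount)  # back to 0: this Buffer is now fair game for reuse
```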
tinygrad/tensor.py:

```diff
@@ -15,7 +15,7 @@ from tinygrad.ops import LoadOps
 from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
 from tinygrad.shape.symbolic import sint
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule_with_vars
 
 # **** start with two base classes, Tensor and Function ****
```
```diff
@@ -145,7 +145,8 @@ class Tensor:
     if getenv("FUZZ_SCHEDULE"):
       from test.external.fuzz_schedule import fuzz_schedule
       fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
-    run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst])))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
+    run_schedule(memory_planner(schedule), var_vals)
 
   def realize(self) -> Tensor:
     """Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""
```
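The end-user effect, as a sketch: every realization now routes its schedule through memory_planner, so with DEBUG=1 the "memory reduced" line can appear even outside the JIT:

```python
from tinygrad import Tensor

xs = [Tensor.rand(32, 32) for _ in range(4)]
Tensor.corealize([(x @ x).sum() for x in xs])  # one schedule, planned as a whole
```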