tinygrad/tinygrad/engine/realize.py
George Hotz 9fc4465557 subbuffer support (#4397)
* subbuffer support
* diskbuffer offset
* cuda subbuffer works
* use subbuffer
* more subbuffer tests
* consecutive
* cast
* consec
* offset
* view is a better name
* offset is in nbytes
* fix view + memory planner
* delete unused DiskRunner
* reverse order
* no subbuffers on unrealized consts
* only enabled for disk
* don't reverse memory
* view supported devices
* pickle buffer view
* ring jit
* support extra view inputs in jit
* fix JIT=2 issue
* test copy jit
* p2p isn't an option anymore
* fix dep tracking issue
* fix mypy
* fix pickle
* from_nv is contents now
2024-05-03 18:05:57 -07:00


from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Union
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import DType
from tinygrad.helpers import colored, getenv, dedup, DEBUG, GlobalCounters, ansilen
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast
from tinygrad.device import Runner, Device, BufferCopy, BufferXfer
from tinygrad.buffer import Buffer
from tinygrad.shape.symbolic import Variable, sym_infer

@dataclass(frozen=True)
class ExecItem:
  prg: Runner
  bufs: List[Optional[Buffer]]
  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
    bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
    et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
    if do_update_stats:
      GlobalCounters.kernel_count += 1
      GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
      GlobalCounters.global_mem += (mem_estimate:=sym_infer(self.prg.mem_estimate, var_vals))
      if et is not None: GlobalCounters.time_sum_s += et
      if DEBUG >= 2:
        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
      self.prg.first_run = False
    return et
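
# A minimal sketch of driving an ExecItem by hand (the names `runner`, `out`, `inp` are
# illustrative, not from this file; `runner` is any compiled Runner and the Buffers live
# on its device):
#   ei = ExecItem(runner, [out, inp])
#   et = ei.run({}, wait=True)  # with wait=True (or DEBUG>=2), et is elapsed seconds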

class CustomOp(Runner):
  def __init__(self, fxn):
    self.fxn = fxn
    super().__init__(self.fxn.__name__, "CUSTOM", 0, 0)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): self.fxn(*rawbufs)
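
# CustomOp wraps an arbitrary Python function as a Runner with zero op/mem estimates; a
# sketch of a LoadOps.CUSTOM payload (`fill_ones` is a hypothetical name):
#   def fill_ones(out:Buffer): out.copyin(memoryview(bytearray(b"\x01"*out.nbytes)))
#   ExecItem(CustomOp(fill_ones), [some_buffer]).run()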

class EmptyOp(Runner):
  def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass

class ViewOp(Runner):
  def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
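
# ViewOp is a no-op at execution time: a LoadOps.VIEW output aliases its input's storage
# (a Buffer created with base= and offset=), so there is nothing to run; the assert above
# only checks that rawbufs[0] really is a view whose base is rawbufs[1].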

def lower_runner(runner:Runner, bufs) -> ExecItem:
  # TODO: globals isn't on the stupid diskrunner, remove the need for it
  return ExecItem(runner, [bufs[x[0]] for x in runner.globals] if hasattr(runner, 'globals') else bufs)

def lower_schedule_item(si:ScheduleItem) -> ExecItem:
  assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY
  if si.ast[0].op is BufferOps.STORE: return lower_runner(Device[si.outputs[0].device].get_runner(*si.ast), si.bufs)
  assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
  out, ast = si.outputs[0], si.ast[0]
  if ast.op is LoadOps.COPY:
    kernel_type = BufferCopy
    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
      if getenv("USE_COPY_KERNEL"): return lower_runner(Device[out.device].get_runner(copy_ast(ast.arg)), si.bufs)
      kernel_type = BufferXfer
    return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
  if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs))
  if ast.op is LoadOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
  if ast.op is LoadOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
  raise RuntimeError(f"don't know how to lower {ast}")
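
# A sketch of the COPY dispatch above, assuming two NV devices: a copy from "NV:1" to
# "NV" lowers to BufferXfer (both split to the "NV" family and the allocator has
# `transfer`), while a copy from "NV" to "CLANG" crosses families and falls back to
# BufferCopy. USE_COPY_KERNEL=1 instead compiles the copy as a regular kernel via copy_ast.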

def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
  while len(schedule): yield lower_schedule_item(schedule.pop(0))

capturing: List = []  # put classes with an add method in here
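
# Anything placed in `capturing` only needs an `add(ei)` method; this is the hook the JIT
# uses to record ExecItems. A minimal sketch of a capture class (illustrative, not part
# of tinygrad):
#   class Capture:
#     def __init__(self): self.items: List[ExecItem] = []
#     def add(self, ei:ExecItem): self.items.append(ei)
#   capturing.append(Capture())  # run_schedule will now record every ExecItem it runs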

def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
  if getenv("NO_MEMORY_PLANNER"): return {}
  last_appearance = {}
  for i,u in enumerate(buffers):
    for buf in u: last_appearance[buf] = i

  # LRU algorithm
  assigned: Dict[Buffer, Buffer] = {}
  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
  def handle_buffer(buf):
    key = (buf.device, buf.size, buf.dtype)
    if buf not in assigned:
      if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
      else: assigned[buf] = Buffer(*key)
    if i == last_appearance[buf]:
      if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
  for i,u in enumerate(buffers):
    for buf in u:
      # all unallocated unparented buffers are fair game to replace
      if buf.is_allocated() or buf.lb_refcount > 0: continue
      # handle view buffers
      if buf._base is not None:
        assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
      else:
        handle_buffer(buf)

  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
          f"{len(ak)} -> {len(av)} bufs")
  return assigned
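
# A worked sketch of the planner (hypothetical buffers a, b, c with identical
# (device, size, dtype), all unallocated and unparented): for per-step buffer lists
# [[a], [b, a], [c, b]], `a` is last used at step 1, so its backing Buffer is released
# into local_cache there, and `c` (first needed at step 2) reuses it: three logical
# buffers, two real allocations.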

def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
  assigned = _internal_memory_planner([si.bufs for si in schedule])
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]

def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
  for ei in lower_schedule(schedule):
    if len(capturing): capturing[0].add(ei)
    ei.run(var_vals)
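
# Typical driver (a sketch; create_schedule lives in tinygrad.engine.schedule, and this
# exact call site is an assumption, not taken from this file):
#   schedule = memory_planner(create_schedule([t.lazydata]))
#   run_schedule(schedule)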