ScheduleItem uses Buffer (#3995)

* schedule Buffer

* update

* update tests

* master

* works

* remove LoadOps.WAIT

* fix compile2

* bad test

* rename and note
This commit is contained in:
George Hotz
2024-03-29 20:50:27 -07:00
committed by GitHub
parent 1bd4f01da2
commit 9eef44521b
9 changed files with 47 additions and 41 deletions

View File

@@ -1,8 +1,7 @@
from typing import List, Dict, Optional
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, JITRunner, update_stats
from tinygrad.features.graph import realized_lazybuffer
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
from tinygrad.helpers import colored, getenv, cpu_time_execution, DEBUG
from tinygrad.shape.symbolic import Variable
class CustomOp(JITRunner):
@@ -20,7 +19,7 @@ class SyncOp(JITRunner):
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op in {LoadOps.COPY, LoadOps.WAIT}
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
out, ast = si.outputs[0], si.ast[0]
@@ -43,12 +42,10 @@ def run_schedule(schedule:List[ScheduleItem]):
for out in si.outputs:
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if out.size > 0 and not dont_allocate and out.op is not LoadOps.ASSIGN: out.buffer.allocate()
if out.size > 0 and not dont_allocate and not hasattr(out, "_buf"): out.allocate()
# run the function (put it in JIT)
real_buffers = [x.buffer for x in si.outputs+si.inputs if x.size != 0]
real_buffers = [x for x in si.outputs+si.inputs if x.size != 0]
assert dont_allocate or all(hasattr(x, "_buf") for x in real_buffers), f"can't run, some inputs aren't realized {real_buffers}"
if prg: prg.exec(real_buffers, si.var_vals)
elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
if GRAPH:
for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count)

View File

@@ -1,8 +1,9 @@
import sys
from collections import defaultdict, deque
from typing import List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.features.graph import log_lazybuffer
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps, GlobalCounters
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
@@ -12,6 +13,14 @@ from tinygrad.shape.shapetracker import ShapeTracker
# creation can recurse a lot
sys.setrecursionlimit(10000)
# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
# Schedule entry whose outputs/inputs are still LazyBuffers. It carries the same four
# fields as ScheduleItem; create_schedule builds the final ScheduleItem from it by taking
# `.buffer` of every output and input when the entry is popped from the toposort queue.
@dataclass(frozen=True)
class _LBScheduleItem:
ast: Tuple[LazyOp, ...]  # LazyOps for this entry; _schedule_one always passes a 1-tuple `(op,)`
outputs: Tuple[LazyBuffer, ...]  # LazyBuffers produced by this entry; `(out,)` when built by _schedule_one
inputs: Tuple[LazyBuffer, ...]  # LazyBuffers read by this entry
var_vals: Dict[Variable, int]  # Variable -> int bindings; _schedule_one seeds it from a copy of out.st.var_vals
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
@@ -63,16 +72,16 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
return ret
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
inputs: List[LazyBuffer] = []
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}:
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
else:
output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
return ScheduleItem((op,), (out,), tuple(inputs), var_vals)
return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
@@ -201,10 +210,15 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
queue = deque(out for out in prescheduled if in_degree[out] == 0)
schedule: List[ScheduleItem] = []
kernel_number = GlobalCounters.kernel_count
while queue:
buf = queue.popleft()
seen.add(buf)
schedule.append(prescheduled[buf])
ps = prescheduled[buf]
if GRAPH:
kernel_number += 1
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals))
for x in graph[buf]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)

View File

@@ -12,10 +12,8 @@ from weakref import ref, ReferenceType, WeakValueDictionary
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None
if op is LoadOps.CONST:
arg = dtypes.as_const(arg, dtype)
enable_cache = True
if st.size == 0 and op is not LoadOps.SYNC: op, arg, srcs, base = LoadOps.CONST, 0, (), None
if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype), True
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
@@ -101,8 +99,7 @@ class LazyBuffer:
# copies in HSA/CUDA to other HSA/CUDA don't sync either
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, sync), enable_cache=False)
def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
# no COPY

View File

@@ -18,14 +18,13 @@ class BinaryOps(Enum):
class TernaryOps(Enum): WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
class LoadOps(Enum):
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); WAIT = auto(); ASSIGN = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); ASSIGN = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
if TYPE_CHECKING:
from tinygrad.lazy import LazyBuffer
from tinygrad.buffer import Buffer
@dataclass(frozen=True)
class MemBuffer:
@@ -42,8 +41,8 @@ class ConstBuffer:
@dataclass(frozen=True)
class ScheduleItem:
ast: Tuple[LazyOp, ...]
outputs: Tuple[LazyBuffer, ...]
inputs: Tuple[LazyBuffer, ...]
outputs: Tuple[Buffer, ...]
inputs: Tuple[Buffer, ...]
var_vals: Dict[Variable, int]
@dataclass(frozen=True, eq=False)