mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 15:15:13 -05:00
ScheduleItem uses Buffer (#3995)
* schedule Buffer * update * update tests * master * works * remove LoadOps.WAIT * fix compile2 * bad test * rename and note
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from typing import List, Dict, Optional
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps
|
||||
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, JITRunner, update_stats
|
||||
from tinygrad.features.graph import realized_lazybuffer
|
||||
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
|
||||
from tinygrad.helpers import colored, getenv, cpu_time_execution, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
class CustomOp(JITRunner):
|
||||
@@ -20,7 +19,7 @@ class SyncOp(JITRunner):
|
||||
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)
|
||||
|
||||
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
|
||||
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op in {LoadOps.COPY, LoadOps.WAIT}
|
||||
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
|
||||
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
|
||||
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
|
||||
out, ast = si.outputs[0], si.ast[0]
|
||||
@@ -43,12 +42,10 @@ def run_schedule(schedule:List[ScheduleItem]):
|
||||
|
||||
for out in si.outputs:
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if out.size > 0 and not dont_allocate and out.op is not LoadOps.ASSIGN: out.buffer.allocate()
|
||||
if out.size > 0 and not dont_allocate and not hasattr(out, "_buf"): out.allocate()
|
||||
|
||||
# run the function (put it in JIT)
|
||||
real_buffers = [x.buffer for x in si.outputs+si.inputs if x.size != 0]
|
||||
real_buffers = [x for x in si.outputs+si.inputs if x.size != 0]
|
||||
assert dont_allocate or all(hasattr(x, "_buf") for x in real_buffers), f"can't run, some inputs aren't realized {real_buffers}"
|
||||
if prg: prg.exec(real_buffers, si.var_vals)
|
||||
elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
|
||||
if GRAPH:
|
||||
for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import sys
|
||||
from collections import defaultdict, deque
|
||||
from typing import List, Dict, Optional, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
|
||||
from tinygrad.features.graph import log_lazybuffer
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps, GlobalCounters
|
||||
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
@@ -12,6 +13,14 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
# creation can recurse a lot
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
|
||||
@dataclass(frozen=True)
|
||||
class _LBScheduleItem:
|
||||
ast: Tuple[LazyOp, ...]
|
||||
outputs: Tuple[LazyBuffer, ...]
|
||||
inputs: Tuple[LazyBuffer, ...]
|
||||
var_vals: Dict[Variable, int]
|
||||
|
||||
# recursively create a lazyop
|
||||
def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
|
||||
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
|
||||
@@ -63,16 +72,16 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
|
||||
LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
|
||||
return ret
|
||||
|
||||
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
|
||||
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
|
||||
inputs: List[LazyBuffer] = []
|
||||
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
|
||||
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
|
||||
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}:
|
||||
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
|
||||
else:
|
||||
output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
|
||||
op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
|
||||
op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
|
||||
return ScheduleItem((op,), (out,), tuple(inputs), var_vals)
|
||||
return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
|
||||
|
||||
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
|
||||
@@ -201,10 +210,15 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
|
||||
queue = deque(out for out in prescheduled if in_degree[out] == 0)
|
||||
schedule: List[ScheduleItem] = []
|
||||
kernel_number = GlobalCounters.kernel_count
|
||||
while queue:
|
||||
buf = queue.popleft()
|
||||
seen.add(buf)
|
||||
schedule.append(prescheduled[buf])
|
||||
ps = prescheduled[buf]
|
||||
if GRAPH:
|
||||
kernel_number += 1
|
||||
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
|
||||
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals))
|
||||
for x in graph[buf]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
@@ -12,10 +12,8 @@ from weakref import ref, ReferenceType, WeakValueDictionary
|
||||
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
|
||||
if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None
|
||||
if op is LoadOps.CONST:
|
||||
arg = dtypes.as_const(arg, dtype)
|
||||
enable_cache = True
|
||||
if st.size == 0 and op is not LoadOps.SYNC: op, arg, srcs, base = LoadOps.CONST, 0, (), None
|
||||
if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype), True
|
||||
|
||||
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
|
||||
if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
|
||||
@@ -101,8 +99,7 @@ class LazyBuffer:
|
||||
# copies in HSA/CUDA to other HSA/CUDA don't sync either
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
|
||||
sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
|
||||
wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, sync), enable_cache=False)
|
||||
|
||||
def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
|
||||
# no COPY
|
||||
|
||||
@@ -18,14 +18,13 @@ class BinaryOps(Enum):
|
||||
class TernaryOps(Enum): WHERE = auto() # noqa: E702
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
|
||||
class LoadOps(Enum):
|
||||
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); WAIT = auto(); ASSIGN = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); ASSIGN = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.buffer import Buffer
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MemBuffer:
|
||||
@@ -42,8 +41,8 @@ class ConstBuffer:
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItem:
|
||||
ast: Tuple[LazyOp, ...]
|
||||
outputs: Tuple[LazyBuffer, ...]
|
||||
inputs: Tuple[LazyBuffer, ...]
|
||||
outputs: Tuple[Buffer, ...]
|
||||
inputs: Tuple[Buffer, ...]
|
||||
var_vals: Dict[Variable, int]
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
|
||||
Reference in New Issue
Block a user