ScheduleItem uses Buffer (#3995)

* schedule Buffer

* update

* update tests

* master

* works

* remove LoadOps.WAIT

* fix compile2

* bad test

* rename and note
This commit is contained in:
George Hotz
2024-03-29 20:50:27 -07:00
committed by GitHub
parent 1bd4f01da2
commit 9eef44521b
9 changed files with 47 additions and 41 deletions

View File

@@ -1,8 +1,9 @@
import sys
from collections import defaultdict, deque
from typing import List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.features.graph import log_lazybuffer
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps, GlobalCounters
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
@@ -12,6 +13,14 @@ from tinygrad.shape.shapetracker import ShapeTracker
# creation can recurse a lot
sys.setrecursionlimit(10000)
# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
# LazyBuffer-level form of a schedule entry: _schedule_one builds one per output, and create_schedule
# converts each into a ScheduleItem by replacing every LazyBuffer in outputs/inputs with its `.buffer`.
# frozen=True only blocks attribute reassignment; the var_vals dict itself stays mutable, and since it is
# a dict, calling hash() on an instance raises TypeError.
@dataclass(frozen=True)
class _LBScheduleItem:
# LazyOp ASTs for this entry (_schedule_one always passes a 1-tuple)
ast: Tuple[LazyOp, ...]
# output LazyBuffers (_schedule_one passes `(out,)`)
outputs: Tuple[LazyBuffer, ...]
# input LazyBuffers (either `out.srcs` or `membufs[1:]` in _schedule_one)
inputs: Tuple[LazyBuffer, ...]
# Variable -> int mapping; starts as a copy of `out.st.var_vals` in _schedule_one
var_vals: Dict[Variable, int]
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
@@ -63,16 +72,16 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
return ret
# Build the schedule entry for a single output LazyBuffer `out`.
#  - If out.op is one of the listed LoadOps, the AST is a bare LazyOp(out.op, (), out.arg) and the inputs are out.srcs.
#  - Otherwise the AST is produced by _recursive_lazyop and wrapped in a BufferOps.STORE to MemBuffer index 0.
# NOTE(review): this span reads like a diff hunk with its +/- markers stripped: the `-> ScheduleItem` signature,
# the LoadOps.WAIT membership test and the `return ScheduleItem(...)` line appear to be the pre-change versions
# of the lines that directly follow them. Confirm before treating the span as one function body.
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
inputs: List[LazyBuffer] = []
# copied so that later changes to var_vals do not touch out.st.var_vals
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}:
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
else:
# the output ShapeTracker uses the shape of reduce_for_op[out] when `out` has an entry there, else out.shape;
# membufs starts as [out], which is why the inputs below are membufs[1:]
# (presumably _recursive_lazyop appends the buffers it loads to membufs and fills var_vals; its body is not visible here)
output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
return ScheduleItem((op,), (out,), tuple(inputs), var_vals)
return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
@@ -201,10 +210,15 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
queue = deque(out for out in prescheduled if in_degree[out] == 0)
schedule: List[ScheduleItem] = []
kernel_number = GlobalCounters.kernel_count
while queue:
buf = queue.popleft()
seen.add(buf)
schedule.append(prescheduled[buf])
ps = prescheduled[buf]
if GRAPH:
kernel_number += 1
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals))
for x in graph[buf]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)