Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00.
ScheduleItem uses Buffer (#3995)
* schedule Buffer * update * update tests * master * works * remove LoadOps.WAIT * fix compile2 * bad test * rename and note
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import sys
|
||||
from collections import defaultdict, deque
|
||||
from typing import List, Dict, Optional, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
|
||||
from tinygrad.features.graph import log_lazybuffer
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps, GlobalCounters
|
||||
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
@@ -12,6 +13,14 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
# creation can recurse a lot
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
@dataclass(frozen=True)
class _LBScheduleItem:
  """Pre-schedule item that still references LazyBuffers.

  Converted to a ScheduleItem (which holds realized Buffers) after the
  toposort in create_schedule.
  """
  ast: Tuple[LazyOp, ...]          # the op tree(s) to execute for this item
  outputs: Tuple[LazyBuffer, ...]  # LazyBuffers written by this item
  inputs: Tuple[LazyBuffer, ...]   # LazyBuffers read by this item
  var_vals: Dict[Variable, int]    # bound values for symbolic shape variables (copied from the output's ShapeTracker)
|
||||
|
||||
# recursively create a lazyop
|
||||
def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
|
||||
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
|
||||
@@ -63,16 +72,16 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
|
||||
LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
|
||||
return ret
|
||||
|
||||
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
  """Build the _LBScheduleItem that realizes a single output LazyBuffer.

  Args:
    out: the LazyBuffer to realize.
    realizes: the set of buffers that will be realized (recursion stops at these).
    reduce_for_op: maps an output to the reduce it was fused with, if any;
      the fused reduce's shape becomes the kernel's output shape.

  Returns:
    An _LBScheduleItem holding the op tree, the output, the inputs it reads,
    and the bound symbolic variable values.
  """
  inputs: List[LazyBuffer] = []
  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
  # NOTE: this commit removed LoadOps.WAIT from the special-cased load ops
  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}:
    # load ops carry no AST body: the op is emitted as-is and its srcs become the inputs
    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
  else:
    # the kernel's output shape is the fused reduce's shape when one exists; membufs[0] is the output
    output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
    op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
    # wrap in a STORE to buffer 0; membufs[1:] are the discovered inputs
    op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
  return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
|
||||
|
||||
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
|
||||
@@ -201,10 +210,15 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
|
||||
queue = deque(out for out in prescheduled if in_degree[out] == 0)
|
||||
schedule: List[ScheduleItem] = []
|
||||
kernel_number = GlobalCounters.kernel_count
|
||||
while queue:
|
||||
buf = queue.popleft()
|
||||
seen.add(buf)
|
||||
schedule.append(prescheduled[buf])
|
||||
ps = prescheduled[buf]
|
||||
if GRAPH:
|
||||
kernel_number += 1
|
||||
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
|
||||
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals))
|
||||
for x in graph[buf]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
Reference in New Issue
Block a user