Mirror of https://github.com/tinygrad/tinygrad.git (last synced 2026-02-04 03:35:16 -05:00).
create_schedule has global vars (#4125)

* abstractions3 is currently wishful thinking
* create_schedule_with_vars
This commit is contained in:
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
|
||||
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int
|
||||
from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int, merge_dicts
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
@@ -126,7 +126,7 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
|
||||
if buf.op in UNSAFE_PAD_OPS: return False
|
||||
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
|
||||
|
||||
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
if seen is None: seen = set()
|
||||
|
||||
# start by just realizing the buffers passed in
|
||||
@@ -218,6 +218,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
|
||||
queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
|
||||
schedule: List[ScheduleItem] = []
|
||||
var_vals: Dict[Variable, int] = {}
|
||||
kernel_number = GlobalCounters.kernel_count
|
||||
while queue:
|
||||
ps = queue.popleft()
|
||||
@@ -225,8 +226,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
if GRAPH:
|
||||
kernel_number += 1
|
||||
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
|
||||
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0),
|
||||
tuple(x.buffer for x in ps.inputs if x.size != 0), ps.var_vals))
|
||||
var_vals = merge_dicts([var_vals, ps.var_vals])
|
||||
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)))
|
||||
for x in graph[ps.outputs[0]]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(prescheduled[x])
|
||||
@@ -234,4 +235,9 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
# confirm everything was scheduled correctly
|
||||
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
|
||||
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
|
||||
return schedule, var_vals
|
||||
|
||||
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
  """Schedule `outs` for realization when no symbolic Variables are involved.

  Thin wrapper over create_schedule_with_vars for callers that cannot handle
  symbolic shapes: it returns only the ScheduleItem list and discards the
  (expected-empty) Variable binding map.

  Args:
    outs: LazyBuffers to realize.
    seen: optional set of already-scheduled LazyBuffers, passed through.

  Returns:
    The list of ScheduleItems produced by create_schedule_with_vars.

  Raises:
    RuntimeError: if the schedule produced any Variable bindings — use
      create_schedule_with_vars directly in that case.
  """
  schedule, var_vals = create_schedule_with_vars(outs, seen)
  # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
  # which would let symbolic bindings be silently dropped.
  if var_vals: raise RuntimeError(f"create_schedule cannot handle symbolic variables {var_vals}, use create_schedule_with_vars")
  return schedule
|
||||
|
||||
Reference in New Issue
Block a user