create schedule has global vars (#4125)

* abstractions3 is currently wishful thinking

* create_schedule_with_vars
This commit is contained in:
George Hotz
2024-04-09 21:42:16 -07:00
committed by GitHub
parent 216eb235e5
commit 08ddeb5685
4 changed files with 14 additions and 9 deletions

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int
from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int, merge_dicts
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
@@ -126,7 +126,7 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
if buf.op in UNSAFE_PAD_OPS: return False
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if seen is None: seen = set()
# start by just realizing the buffers passed in
@@ -218,6 +218,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
schedule: List[ScheduleItem] = []
var_vals: Dict[Variable, int] = {}
kernel_number = GlobalCounters.kernel_count
while queue:
ps = queue.popleft()
@@ -225,8 +226,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
if GRAPH:
kernel_number += 1
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0),
tuple(x.buffer for x in ps.inputs if x.size != 0), ps.var_vals))
var_vals = merge_dicts([var_vals, ps.var_vals])
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)))
for x in graph[ps.outputs[0]]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(prescheduled[x])
@@ -234,4 +235,9 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
# confirm everything was scheduled correctly
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
return schedule, var_vals
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
schedule, var_vals = create_schedule_with_vars(outs, seen)
assert len(var_vals) == 0
return schedule