From 08ddeb56854bd64185e4c815c7a398cc8280d3ce Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 9 Apr 2024 21:42:16 -0700
Subject: [PATCH] create schedule has global vars (#4125)

* abstractions3 is currently wishful thinking

* create_schedule_with_vars
---
 tinygrad/engine/realize.py  |  4 ++--
 tinygrad/engine/schedule.py | 14 ++++++++++----
 tinygrad/ops.py             |  1 -
 tinygrad/tensor.py          |  4 ++--
 4 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 26e92bce0a..7103fb9ec4 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -23,7 +23,7 @@ def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
   return None
 
 logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
-def run_schedule(schedule:List[ScheduleItem]):
+def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]] = None):
   while len(schedule):
     si = schedule.pop(0)
     if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
@@ -37,5 +37,5 @@ def run_schedule(schedule:List[ScheduleItem]):
 
     # run the function (put it in JIT)
     real_buffers = [x for x in si.outputs+si.inputs if x.size != 0]
-    if prg: prg.exec(real_buffers, si.var_vals)
+    if prg: prg.exec(real_buffers, var_vals if var_vals is not None else {})
     elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 68e549dbea..4bb28e1f17 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import Tuple, List, Dict, Optional, Set, DefaultDict
 from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
 from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
-from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int
+from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int, merge_dicts
 from tinygrad.shape.symbolic import Variable
 from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.lazy import LazyBuffer
@@ -126,7 +126,7 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
   if buf.op in UNSAFE_PAD_OPS: return False
   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
 
-def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
   if seen is None: seen = set()
 
   # start by just realizing the buffers passed in
@@ -218,6 +218,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
 
   queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
   schedule: List[ScheduleItem] = []
+  var_vals: Dict[Variable, int] = {}
   kernel_number = GlobalCounters.kernel_count
   while queue:
     ps = queue.popleft()
@@ -225,8 +226,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
     if GRAPH:
       kernel_number += 1
       for out in ps.outputs: realized_lazybuffer(out, kernel_number)
-    schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0),
-                                 tuple(x.buffer for x in ps.inputs if x.size != 0), ps.var_vals))
+    var_vals = merge_dicts([var_vals, ps.var_vals])
+    schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)))
     for x in graph[ps.outputs[0]]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(prescheduled[x])
@@ -234,4 +235,9 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
   # confirm everything was scheduled correctly
   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+  return schedule, var_vals
+
+def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+  schedule, var_vals = create_schedule_with_vars(outs, seen)
+  assert len(var_vals) == 0
   return schedule
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index be5bf1d7bc..bdc8bb855a 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -41,7 +41,6 @@ class ScheduleItem:
   ast: Tuple[LazyOp, ...]
   outputs: Tuple[Buffer, ...]
   inputs: Tuple[Buffer, ...]
-  var_vals: Dict[Variable, int]
 
 @dataclass(frozen=True, eq=False)
 class LazyOp:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index a17503ef7c..24f6844c9d 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -15,7 +15,7 @@ from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
 from tinygrad.shape.symbolic import sint
 from tinygrad.engine.realize import run_schedule
-from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.schedule import create_schedule_with_vars
 
 # **** start with two base classes, Tensor and Function ****
 
@@ -138,7 +138,7 @@ class Tensor:
 
   @staticmethod
   def corealize(lst:Iterable[Tensor]):
-    run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
+    run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
 
   def realize(self) -> Tensor:
     Tensor.corealize([self])
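
Usage sketch of the interface after this patch (a minimal sketch, not part of the diff above): Variable bindings are now merged into one dict per schedule rather than stored on each ScheduleItem, and run_schedule takes that dict as a second argument. Assumes a single-device tensor so lazydata is a plain LazyBuffer; the tensor t is only illustrative.

    from tinygrad.tensor import Tensor
    from tinygrad.engine.schedule import create_schedule_with_vars
    from tinygrad.engine.realize import run_schedule

    t = Tensor.rand(4, 4) + 1                          # build a lazy computation (illustrative)
    # the schedule now comes back alongside a single merged var_vals dict
    schedule, var_vals = create_schedule_with_vars([t.lazydata])
    # pass the merged dict to run_schedule; this matches Tensor.corealize's
    # run_schedule(*create_schedule_with_vars(...)) call in the patch
    run_schedule(schedule, var_vals)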