mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
SAVE_SCHEDULE as contextvar (#4230)
This commit is contained in:
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
|
||||
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
@@ -264,7 +264,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
|
||||
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
|
||||
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
||||
if getenv("SAVE_SCHEDULE"):
|
||||
if SAVE_SCHEDULE:
|
||||
def _save():
|
||||
print(f"saving {len(SCHEDULES)} schedule items to", fp:="schedule.pkl")
|
||||
pickle.dump(SCHEDULES, open(fp, "wb"))
|
||||
|
||||
@@ -97,7 +97,7 @@ class ContextVar:
|
||||
|
||||
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
|
||||
WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
|
||||
GRAPH, GRAPHPATH, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("RING", 1)
|
||||
GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
|
||||
MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
|
||||
|
||||
# **************** global state Counters ****************
|
||||
|
||||
Reference in New Issue
Block a user