diff --git a/.gitignore b/.gitignore index 8dff6dbd1d..a897885df0 100644 --- a/.gitignore +++ b/.gitignore @@ -52,4 +52,5 @@ quickstart.py .hypothesis weights *.lprof +*.pkl site/ diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 52293a1e2c..765c220873 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,4 +1,4 @@ -import sys +import sys, pickle, atexit from collections import defaultdict, deque from dataclasses import dataclass from typing import Tuple, List, Dict, Optional, Set, DefaultDict @@ -238,6 +238,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul return graph, in_degree, prescheduled +SCHEDULES: List[ScheduleItem] = [] def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: if seen is None: seen = set() graph, in_degree, prescheduled = _graph_schedule(outs, seen) @@ -263,6 +264,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule): raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}") if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels") + if getenv("SAVE_SCHEDULE"): + def _save(): + print(f"saving {len(SCHEDULES)} schedule items to", fp:="schedule.pkl") + pickle.dump(SCHEDULES, open(fp, "wb")) + if len(SCHEDULES) == 0: atexit.register(_save) + SCHEDULES.extend(schedule) return schedule, var_vals def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]: