diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index b7d530c45b..5232105a4d 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -from collections import defaultdict import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings from typing import Callable, cast from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm -from tinygrad.engine.schedule import ScheduleContext, schedule_uop +from tinygrad.engine.schedule import create_schedule_with_vars from tinygrad.codegen.kernel import Kernel, Opt from tinygrad.renderer import Renderer from tinygrad.ops import UOp @@ -30,9 +29,9 @@ class ProcessReplayWarning(Warning): pass # *** recreators -def recreate_sched(ast:UOp) -> UOp: - # NOTE: process replay isn't meant to actually schedule anything - return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list)), {}).ast +def recreate_sched(big_sink:UOp) -> list[UOp]: + sched, _, __ = create_schedule_with_vars(big_sink) + return [x.ast for x in sched] def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str: k = Kernel(ast, opts=opts) for opt in applied_opts: k.apply_opt(opt) @@ -42,7 +41,8 @@ def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> # *** diff a "good" recreation against the generated version def diff(offset:int, name:str, fxn:Callable) -> None: - if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning) + # TODO: add this assert back for schedule + if ASSERT_DIFF and name != "schedule": warnings.filterwarnings("error", category=ProcessReplayWarning) if early_stop.is_set(): return None conn = db_connection() cur = conn.cursor() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index c773a0c34f..18735ff53a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,4 +1,4 @@ -import sys, atexit, functools, pickle +import sys, functools from collections import defaultdict, deque from dataclasses import dataclass, field from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers @@ -396,17 +396,8 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> Sched # otherwise, it's not fine else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) - # capture process replay - if CAPTURE_PROCESS_REPLAY: - with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast)) return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None))) -PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {} -if CAPTURE_PROCESS_REPLAY: - @atexit.register - def save_process_replay() -> None: - for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) - # **** schedule creation and toposort @track_rewrites(named=True) @@ -473,4 +464,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va # confirm everything was scheduled correctly if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}") if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels") + # capture process replay + if CAPTURE_PROCESS_REPLAY: + with Context(PICKLE_BUFFERS=0): + diskcache_put("schedule_process_replay", str(big_sink.key), (big_sink, ContextVar._cache, [x.ast for x in schedule])) return schedule, var_vals, becomes_map