process replay from tensor uops to kernel ast (#8883)

* process replay from tensor uops to kernel ast

* this dedups

* switch back to string key
authored by qazal on 2025-02-04 18:09:20 +01:00, committed by GitHub
parent dcf104ee68
commit acf0baefee
2 changed files with 11 additions and 16 deletions
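
In short: before this change, schedule_uop captured one pickle per ScheduleItem, buffered it in a module-level PROCESS_REPLAY_CAPTURE dict, and flushed everything to the disk cache at interpreter exit. After it, create_schedule_with_vars captures once per big_sink, writing (big_sink, ContextVar._cache, [x.ast for x in schedule]) directly under the key str(big_sink.key). Keying by the sink's string key is presumably what "this dedups" refers to: identical tensor graphs produce identical keys, so re-captures overwrite a single entry instead of accumulating one per kernel. A minimal sketch of that dedup idea, with a plain dict standing in for tinygrad's sqlite-backed diskcache:

import pickle

# illustrative stand-in for the disk cache; the real captures go through diskcache_put
store: dict[str, bytes] = {}

def capture(key: str, obj) -> None:
  store[key] = pickle.dumps(obj)  # same key -> same slot: duplicate graphs collapse

capture("sink-key-A", ("graph", ["ast1", "ast2"]))
capture("sink-key-A", ("graph", ["ast1", "ast2"]))  # overwrites, does not accumulate
assert len(store) == 1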

File: test/external/process_replay/process_replay.py

@@ -1,10 +1,9 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-from collections import defaultdict
 import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
 from typing import Callable, cast
 from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
-from tinygrad.engine.schedule import ScheduleContext, schedule_uop
+from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.codegen.kernel import Kernel, Opt
 from tinygrad.renderer import Renderer
 from tinygrad.ops import UOp
@@ -30,9 +29,9 @@ class ProcessReplayWarning(Warning): pass
 # *** recreators
-def recreate_sched(ast:UOp) -> UOp:
-  # NOTE: process replay isn't meant to actually schedule anything
-  return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list)), {}).ast
+def recreate_sched(big_sink:UOp) -> list[UOp]:
+  sched, _, __ = create_schedule_with_vars(big_sink)
+  return [x.ast for x in sched]
 def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str:
   k = Kernel(ast, opts=opts)
   for opt in applied_opts: k.apply_opt(opt)
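
The replay side gets simpler too: recreate_sched now takes the captured big_sink and returns the AST of every scheduled kernel in one call, with no ScheduleContext plumbing. A hypothetical driver sketch (compare_captured_schedule is not part of this diff; the payload layout mirrors the capture added to create_schedule_with_vars below):

import difflib, pickle

def compare_captured_schedule(payload: bytes) -> list[str]:
  # layout matches the new capture: (big_sink, ContextVar cache, known-good ASTs)
  big_sink, _ctx_cache, good_asts = pickle.loads(payload)
  recreated = recreate_sched(big_sink)  # list[UOp], one AST per scheduled kernel
  mismatches: list[str] = []
  for good, new in zip(good_asts, recreated):
    if str(good) != str(new):
      mismatches.append("\n".join(difflib.unified_diff(str(good).splitlines(), str(new).splitlines())))
  return mismatches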
@@ -42,7 +41,8 @@ def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) ->
 # *** diff a "good" recreation against the generated version
 def diff(offset:int, name:str, fxn:Callable) -> None:
-  if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
+  # TODO: add this assert back for schedule
+  if ASSERT_DIFF and name != "schedule": warnings.filterwarnings("error", category=ProcessReplayWarning)
   if early_stop.is_set(): return None
   conn = db_connection()
   cur = conn.cursor()
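
The ASSERT_DIFF gate relies on a standard-library technique: warnings.filterwarnings("error", category=...) promotes matching warnings to exceptions, so a ProcessReplayWarning aborts the run only when the assert is enabled (and, per the TODO, never for the schedule for now). A self-contained sketch of that technique:

import warnings

class ProcessReplayWarning(Warning): pass

def check(strict: bool) -> str:
  with warnings.catch_warnings():  # keep the filter change local to this call
    if strict: warnings.simplefilter("error", ProcessReplayWarning)
    try:
      warnings.warn("kernel mismatch", ProcessReplayWarning)
      return "logged"   # non-strict: the warning is printed and execution continues
    except ProcessReplayWarning:
      return "raised"   # strict: the same warn() call becomes an exception

assert check(strict=False) == "logged"
assert check(strict=True) == "raised"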

File: tinygrad/engine/schedule.py

@@ -1,4 +1,4 @@
-import sys, atexit, functools, pickle
+import sys, functools
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
@@ -396,17 +396,8 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> Sched
   # otherwise, it's not fine
   else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                            +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  # capture process replay
-  if CAPTURE_PROCESS_REPLAY:
-    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
   return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
-
-PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
-if CAPTURE_PROCESS_REPLAY:
-  @atexit.register
-  def save_process_replay() -> None:
-    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
 
 # **** schedule creation and toposort
 @track_rewrites(named=True)
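
What was deleted here is a buffer-and-flush pattern: captures accumulated in a module-level dict during the run and were written out in a single atexit pass. A minimal sketch of that removed pattern, with a dict standing in for the disk cache (FAKE_DISKCACHE is illustrative):

import atexit, pickle

PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}  # buffered in memory during the run
FAKE_DISKCACHE: dict[str, bytes] = {}          # stand-in for diskcache_put's store

def capture(key: str, obj) -> None:
  PROCESS_REPLAY_CAPTURE[key] = pickle.dumps(obj)

@atexit.register
def save_process_replay() -> None:
  # one flush at interpreter exit; this commit drops the indirection and writes
  # each capture immediately in create_schedule_with_vars instead
  for k, v in PROCESS_REPLAY_CAPTURE.items(): FAKE_DISKCACHE[k] = v

Writing directly is also what lets this file drop its atexit and pickle imports in the first hunk above.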
@@ -473,4 +464,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   # confirm everything was scheduled correctly
   if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
+  # capture process replay
+  if CAPTURE_PROCESS_REPLAY:
+    with Context(PICKLE_BUFFERS=0):
+      diskcache_put("schedule_process_replay", str(big_sink.key), (big_sink, ContextVar._cache, [x.ast for x in schedule]))
   return schedule, var_vals, becomes_map
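
The capture runs under Context(PICKLE_BUFFERS=0) so the pickle records graph structure rather than tensor contents. The underlying technique is flag-gated serialization: an object's __reduce__ consults a flag and drops the heavy payload when it is off. A sketch of that idea (FakeBuffer and PICKLE_DATA are illustrative; they are not tinygrad's Buffer or its ContextVar):

import pickle

PICKLE_DATA = True  # stand-in for a context flag consulted at pickle time

class FakeBuffer:
  def __init__(self, data: bytes): self.data = data
  def __reduce__(self):
    # with the flag off, serialize an empty shell: replay needs the graph, not the data
    return (FakeBuffer, (self.data if PICKLE_DATA else b"",))

buf = FakeBuffer(b"x" * 1024)
full = len(pickle.dumps(buf))
PICKLE_DATA = False
assert len(pickle.dumps(buf)) < full  # the 1 KiB payload stays out of the capture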