From 16afe04f45fb84e73afee379dd1b4e42ec002414 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 10 Apr 2025 18:27:42 +0800 Subject: [PATCH] move process replay to grouper (#9830) * simpler * sched --- .../external/process_replay/process_replay.py | 11 ++++++----- tinygrad/engine/grouper.py | 19 +++++++++++++++++-- tinygrad/engine/schedule.py | 13 +------------ 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 915e180ca3..a588ecdb9a 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -2,11 +2,11 @@ # compare kernels created by HEAD against master import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings from typing import Callable, cast -from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm -from tinygrad.engine.schedule import create_schedule_with_vars +from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, dedup +from tinygrad.engine.grouper import get_becomes_map from tinygrad.codegen.kernel import Kernel, Opt from tinygrad.renderer import Renderer -from tinygrad.ops import UOp +from tinygrad.ops import UOp, Ops # *** process replay settings @@ -34,8 +34,9 @@ class ProcessReplayWarning(Warning): pass # *** recreators def recreate_sched(big_sink:UOp) -> list[UOp]: - sched, _, __ = create_schedule_with_vars(big_sink) - return [x.ast for x in sched] + sched_sink = get_becomes_map(big_sink)[0][big_sink] + return dedup(u.src[1].arg.ast for u in sched_sink.toposort if u.op is Ops.ASSIGN) + def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str: k = Kernel(ast, opts=opts) for opt in applied_opts: k.apply_opt(opt) diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 8a43655aba..9630389e75 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -4,8 +4,8 @@ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, grap from tinygrad.ops import can_pad, sint, track_rewrites from tinygrad.codegen.lowerer import get_contraction from tinygrad.codegen.symbolic import symbolic_simple -from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize -from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP +from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize, ContextVar, Context, diskcache_put +from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY from tinygrad.dtype import ImageDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View, strides_for_shape @@ -439,6 +439,13 @@ def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str: kcount = len({u.src[1] for u in ret[0].values() if u.op is Ops.ASSIGN}) return f"Schedule {pluralize('Kernel', kcount)}"+(f" (with_{pluralize('Var', len(ret[1]))})" if ret[1] else "") +PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} +if CAPTURE_PROCESS_REPLAY: + import atexit + @atexit.register + def save_process_replay(): + for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) + @track_rewrites(name_fxn=get_name) def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]: # merge_views + simplify @@ -487,4 +494,12 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]: var_vals: dict[Variable, int] = {} sched_sink = graph_rewrite(sched_sink, create_ast, ctx=var_vals, bottom_up=True) becomes_map[big_sink] = sched_sink + + # capture process replay + if CAPTURE_PROCESS_REPLAY: + with Context(PICKLE_BUFFERS=0): + import pickle + asts = dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL) + PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, asts)) + return becomes_map, var_vals diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 3f583f6dde..5d8d9a89d3 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,9 +1,8 @@ -import atexit, pickle from dataclasses import dataclass from collections import deque from tinygrad.ops import UOp, Variable, Ops, buffers from tinygrad.device import Buffer -from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, unwrap +from tinygrad.helpers import Metadata, DEBUG, unwrap from tinygrad.engine.grouper import get_becomes_map # **** ScheduleItem return type @@ -14,12 +13,6 @@ class ScheduleItem: bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] = () -PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} -if CAPTURE_PROCESS_REPLAY: - @atexit.register - def save_process_replay(): - for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) - # **** schedule linearizer def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: @@ -53,10 +46,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va if len(schedule) != len(in_degree): raise RuntimeError(f"created {len(in_degree)} kernels but only scheduled {len(schedule)}") if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels") - # capture process replay - if CAPTURE_PROCESS_REPLAY: - with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule])) - # map ASSIGN to BUFFER after ScheduleItems are constructed for k,v in becomes_map.items(): if v.base.op is Ops.ASSIGN: