scheduler + process_replay import cleanup (#8711)

This commit is contained in:
qazal
2025-01-22 05:44:07 -05:00
committed by GitHub
parent e3d1464ba4
commit 2dae467b75
2 changed files with 9 additions and 9 deletions

View File

@@ -2,7 +2,7 @@
# compare kernels created by HEAD against master
from collections import defaultdict
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, List, Tuple, Union, cast
from typing import Callable, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import ScheduleContext, schedule_uop
from tinygrad.codegen.kernel import Kernel, Opt
@@ -33,15 +33,15 @@ class ProcessReplayWarning(Warning): pass
def recreate_sched(ast:UOp) -> UOp:
# NOTE: process replay isn't meant to actually schedule anything
return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str:
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str:
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
# NOTE: replay with the captured renderer, not the one in master
return k.opts.render(name, cast(List,k.to_program().uops))
return k.opts.render(name, cast(list,k.to_program().uops))
# *** diff a "good" recreation against the generated version
def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
def diff(offset:int, name:str, fxn:Callable) -> tuple[int, int]|bool:
if early_stop.is_set(): return True
conn = db_connection()
cur = conn.cursor()
@@ -95,7 +95,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
cur.close()
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
inputs = list(range(0, row_count, PAGE_SIZE))
ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
ret: list[tuple[int, int]|bool] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
pool.close()
pool.join()
pool.terminate()

View File

@@ -1,10 +1,10 @@
import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers
from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
from tinygrad.dtype import DType, ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape