tinygrad/test/external/process_replay/process_replay.py
#!/usr/bin/env python3
# compare kernels created by HEAD against master
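# each row in the {name}_{TABLE_NAME} sqlite tables is a pickled tuple of (recreator inputs..., ContextVar
# values, captured output); diff() below replays the inputs through the current checkout and compares the
# result against the captured output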
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools
from typing import Callable, List, Tuple, Union, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import full_ast_rewrite
from tinygrad.codegen.kernel import Kernel, Opt
from tinygrad.renderer import Renderer
from tinygrad.ops import UOp
from test.helpers import print_diff
# *** process replay settings
# internal
PAGE_SIZE = 100
REF = os.getenv("GITHUB_REF_NAME", "")
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
TABLE_NAME = f"process_replay_{VERSION}"
os.environ["RUN_PROCESS_REPLAY"] = "0"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format="%(message)s")
# user config
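# a "[pr]" tag in the commit message or PR title (or a local run, where neither env var is set) turns any
# detected diff into a hard failure; ASSERT_PROCESS_REPLAY=0 overrides that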
ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in os.getenv("PR_TITLE", flag))
if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
# *** recreators
def recreate_sched(*args) -> UOp: return full_ast_rewrite(*args[0])[0]
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, _) -> str:
  k = Kernel(ast, opts=opts)
  for opt in applied_opts: k.apply_opt(opt)
  # NOTE: replay with the captured renderer, not the one in master
  return k.opts.render(name, cast(List,k.to_program().uops))
# *** diff a "good" recreation against the generated version
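# returns (additions, deletions) line counts for this page of rows, or True when a row can't be replayed
# (unpickle/recreate failure) while ASSERT_DIFF is set, or when another worker already hit the early stop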
def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
  if early_stop.is_set(): return True
  conn = db_connection()
  cur = conn.cursor()
  cur.execute(f"SELECT val FROM '{name}_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  additions, deletions, changed = 0, 0, 0
  for row in cur.fetchall():
    # try unpickle
    try: args = pickle.loads(row[0])
    except Exception as e:
      logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}")
      if ASSERT_DIFF: return True
      continue
    # try recreate
    try:
      with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
    except Exception as e:
      logging.warning(f"FAILED TO RECREATE KERNEL {e}")
      for x in args[:-1]: logging.info(x)
      if ASSERT_DIFF: return True
      continue
    # diff kernels
    try: assert args[-1] == good
    except AssertionError:
      changed += 1
      logging.info("PROCESS REPLAY DETECTED CHANGE")
      for x in args[:-1]: logging.info(x)
      print_diff(good, args[-1])
      changes = list(difflib.unified_diff(str(good).splitlines(), str(args[-1]).splitlines()))
      additions += len([x for x in changes if x.startswith("+")])
      deletions += len([x for x in changes if x.startswith("-")])
      if ASSERT_DIFF: return additions, deletions
    # stop diffing early once too many rows in this page have changed
    if changed > MAX_DIFF_PCT:
      logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of {name}s. skipping further diff generation.")
      early_stop.set()
      break
  conn.commit()
  cur.close()
  return additions, deletions
# *** generic runner for executing fxn across all rows of a table in parallel
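# pages the table into PAGE_SIZE chunks and hands each offset to diff() in a spawned worker pool;
# results are summed, and with ASSERT_DIFF any detected change fails the run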
def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
  conn = db_connection()
  cur = conn.cursor()
  try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
  except sqlite3.OperationalError:
    logging.warning(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
    return None
  conn.commit()
  cur.close()
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))
    ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
    pool.close()
    pool.join()
    pool.terminate()
    changed = [bool(x[0] or x[1]) if isinstance(x, tuple) else x for x in ret]
    insertions, deletions = [x[0] for x in ret if isinstance(x, tuple)], [x[1] for x in ret if isinstance(x, tuple)]
    logging.info(f"{sum(changed)} kernels changed")
    if sum(insertions) != 0: logging.info(colored(f"{sum(insertions)} insertions(+)", "green"))
    if sum(deletions) != 0: logging.info(colored(f"{sum(deletions)} deletions(-)", "red"))
    if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
# *** main loop
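# replay schedules first, then kernels; a failure in one table only aborts the run when ASSERT_DIFF is set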
if __name__ == "__main__":
  if SKIP_PROCESS_REPLAY:
    logging.info("skipping process replay.")
    exit(0)
  for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
    logging.info(f"***** {name} diff")
    try: _pmap(name, fxn)
    except Exception as e:
      if ASSERT_DIFF: raise e
      logging.error(f"{name} diff err {e}")
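# rough local usage (a sketch; the capture side lives in the test helpers, not in this file):
#   RUN_PROCESS_REPLAY=1 python3 -m pytest test/test_ops.py    # capture kernels with the current code
#   ... make changes / check out the branch to compare ...
#   python3 test/external/process_replay/process_replay.py     # replay the captured inputs and diff
# "[skip_process_replay]" in COMMIT_MESSAGE or PR_TITLE skips the run entirely (see SKIP_PROCESS_REPLAY above)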