From 2800520dd5dfde8b83d3db0d08ef11de547470e9 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 8 Oct 2024 15:43:22 +0300 Subject: [PATCH] even smaller process_replay.py [pr] (#6941) * even smaller process_replay.py [pr] * delete those tests * dedup asts --- .../external/process_replay/process_replay.py | 83 +++++++------------ .../process_replay/test_process_replay.py | 49 ----------- .../process_replay/test_process_replay.sh | 12 --- tinygrad/codegen/kernel.py | 2 +- tinygrad/engine/schedule.py | 4 +- 5 files changed, 32 insertions(+), 118 deletions(-) delete mode 100644 test/external/process_replay/test_process_replay.py delete mode 100755 test/external/process_replay/test_process_replay.sh diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 1ae9152752..b23ce4a7d3 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -import os, multiprocessing, logging, pickle, sqlite3, difflib +import os, multiprocessing, logging, pickle, sqlite3, difflib, functools from typing import Callable, List, Tuple, Union, cast from tinygrad.engine.schedule import full_ast_rewrite from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm -from tinygrad.codegen.kernel import Kernel -from test.external.process_replay.helpers import print_diff +from tinygrad.codegen.kernel import Kernel, Opt +from test.external.process_replay.helpers import ProcessReplayContext, print_diff +from tinygrad.ops import UOp +from tinygrad.renderer import Renderer # *** process replay settings @@ -25,77 +27,50 @@ if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0 SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") if REF == "master": SKIP_PROCESS_REPLAY = True -# *** differs +# *** recreators -def diff_schedule(offset:int) -> bool: - conn = db_connection() - cur = conn.cursor() - cur.execute(f"SELECT val FROM 'schedule_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) - changed = 0 - for row in cur.fetchall(): - # try unpickle - try: raw_ast, ctx, compare_ast = pickle.loads(row[0]) - except Exception as e: - logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}") - if ASSERT_DIFF: return True - continue - # try full_ast_rewrite - try: good_ast = full_ast_rewrite(raw_ast, ctx) - except Exception as e: - logging.warning(f"FAILED TO DO AST REWRITE {e}") - logging.info(raw_ast) - logging.info(ctx) - if ASSERT_DIFF: return True - continue - # diff asts - try: assert compare_ast == good_ast - except AssertionError: - logging.info("PROCESS REPLAY DETECTED CHANGE") - logging.info(raw_ast) - logging.info(ctx) - print_diff(good_ast, compare_ast) - return bool(changed) +def recreate_sched(sink:UOp, ctx) -> UOp: return full_ast_rewrite(sink, ctx) +def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext) -> str: + with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}): + k = Kernel(ast, opts=opts) + for opt in applied_opts: k.apply_opt(opt) + # NOTE: replay with the captured renderer, not the one in master + return k.opts.render(name, cast(List,k.to_program().uops)) -def diff_kernel(offset:int) -> Union[Tuple[int, int], bool]: +# *** diff a "good" recreation against the generated version + +def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]: if early_stop.is_set(): return True conn = db_connection() cur = conn.cursor() - cur.execute(f"SELECT val FROM 'kernel_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) + cur.execute(f"SELECT val FROM '{name}_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) additions, deletions, changed = 0, 0, 0 for row in cur.fetchall(): # try unpickle - try: ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0]) + try: args = pickle.loads(row[0]) except Exception as e: logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}") if ASSERT_DIFF: return True continue - # try linearize - try: - with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}): - k = Kernel(ast, opts=opts) - for opt in applied_opts: k.apply_opt(opt) - # NOTE: replay with the captured renderer, not the one in master - good_src = k.opts.render(name, cast(List,k.to_program().uops)) + # try recreate + try: good = fxn(*args[:-1]) except Exception as e: logging.warning(f"FAILED TO RECREATE KERNEL {e}") - logging.info(ast) - logging.info(applied_opts) + for x in args[:-1]: logging.info(x) if ASSERT_DIFF: return True continue # diff kernels - try: assert compare_src == good_src + try: assert args[-1] == good except AssertionError: logging.info("PROCESS REPLAY DETECTED CHANGE") - logging.info(ast) - logging.info(applied_opts) - logging.info(ctx.loc) - print_diff(good_src, compare_src) - changes = list(difflib.unified_diff(str(good_src).splitlines(), str(compare_src).splitlines())) + for x in args[:-1]: logging.info(x) + print_diff(good, args[-1]) + changes = list(difflib.unified_diff(str(good).splitlines(), str(args[-1]).splitlines())) additions += len([x for x in changes if x.startswith("+")]) deletions += len([x for x in changes if x.startswith("-")]) if ASSERT_DIFF: return additions, deletions if changed > MAX_DIFF_PCT: - logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.") + logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of {name}s. skipping further diff generation.") early_stop.set() break conn.commit() @@ -104,7 +79,7 @@ def diff_kernel(offset:int) -> Union[Tuple[int, int], bool]: # *** generic runner for executing fxn across all rows of a table in parallel -def _pmap(name:str, fxn:Callable[[int], Union[bool, Tuple[int, int]]], maxtasksperchild:int=16) -> None: +def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: conn = db_connection() cur = conn.cursor() try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0] @@ -115,7 +90,7 @@ def _pmap(name:str, fxn:Callable[[int], Union[bool, Tuple[int, int]]], maxtasksp cur.close() with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool: inputs = list(range(0, row_count, PAGE_SIZE)) - ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs))) + ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs))) pool.close() pool.join() pool.terminate() @@ -133,7 +108,7 @@ if __name__ == "__main__": logging.info("skipping process replay.") exit(0) - for name,fxn in [("schedule", diff_schedule), ("kernel", diff_kernel)]: + for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]: logging.info(f"***** {name} diff") try: _pmap(name, fxn) except Exception as e: diff --git a/test/external/process_replay/test_process_replay.py b/test/external/process_replay/test_process_replay.py deleted file mode 100644 index 0f26ac941c..0000000000 --- a/test/external/process_replay/test_process_replay.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest -import contextlib, sqlite3 -from test.external.process_replay.helpers import ProcessReplayContext -from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel - -from tinygrad.codegen.kernel import Kernel -from tinygrad.helpers import to_function_name, db_connection, diskcache_put, VERSION -from tinygrad.ops import UOp -from tinygrad.renderer.cstyle import ClangRenderer -from tinygrad.tensor import Tensor - -def helper_append_replay(ast:UOp, name:str, src:str) -> int: - name = f"kernel_{TABLE_NAME}" - diskcache_put(name.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, ProcessReplayContext({}))) - conn = db_connection() - row_count = conn.execute(f"select count(*) from '{name}'").fetchone()[0] - return row_count - -class TestProcessReplay(unittest.TestCase): - def tearDown(self): - conn = db_connection() - cur = conn.cursor() - with contextlib.suppress(sqlite3.OperationalError): cur.execute(f"DELETE FROM 'kernel_{TABLE_NAME}' WHERE key LIKE 'test_%'") - conn.commit() - cur.close() - - def test_simple_diff(self): - out = Tensor([1, 2, 3])+1 - ast = out.schedule()[-1].ast - test_src = """ -void test(int* restrict a, const int* restrict b) { - for (int ridx0 = 0; ridx0 < 3; ridx0++) { - int val0 = b[ridx0]; - a[ridx0] = (val0+1); - } -} - """ - offset = helper_append_replay(ast, "test", test_src) - assert diff_kernel(offset-1) == (5, 4) - - def test_identical_run(self): - out = Tensor([1, 2, 3])+1 - ast = out.schedule()[-1].ast - test_prg = Kernel(ast, ClangRenderer()).to_program() - offset = helper_append_replay(ast, test_prg.name, test_prg.src) - assert diff_kernel(offset) == (0, 0) - -if __name__ == "__main__": - unittest.main() diff --git a/test/external/process_replay/test_process_replay.sh b/test/external/process_replay/test_process_replay.sh deleted file mode 100755 index ce6cd9b5be..0000000000 --- a/test/external/process_replay/test_process_replay.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -# should assert -sed -i 's/temp/temp1/g' ./tinygrad/codegen/kernel.py -COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null -if [[ $? -eq 0 ]]; then - echo "PROCESS REPLAY IS WRONG." - exit 1 -fi -# should NOT assert -git stash > /dev/null -COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 0d4eabf7f8..2c49864c70 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -728,7 +728,7 @@ class Kernel: if getenv("RUN_PROCESS_REPLAY"): from test.external.process_replay.helpers import get_process_replay_ctx - diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, get_process_replay_ctx())) + diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, get_process_replay_ctx(), src)) # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes # TODO: these max and min don't work on symbolic, and results are very wrong. diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 4ca5a01c3b..2eb4cd4a3a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,4 +1,4 @@ -import sys, pickle, atexit, uuid +import sys, pickle, atexit from collections import defaultdict, deque from dataclasses import dataclass from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast @@ -131,7 +131,7 @@ def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp: if not AST_REWRITE: return base_sink sink = graph_rewrite(base_sink, reduceop_fusor) ret = graph_rewrite(sink, enumerate_bufs, ctx) - if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_process_replay", str(uuid.uuid4()), (base_sink, ctx, ret)) + if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret)) return ret # *** List[LazyBuffer] lowering to ScheduleItem ***