From 3c378efcb60506ad733f90430df7054faf4c123f Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 15 Jul 2024 05:09:28 +0800 Subject: [PATCH] process replay docs improvements (#5481) * minor cleanups * docs and logs * shorter * comma * s/print/logging.info [run_process_replay] * use logging.warn * process name is noise * revert lowerer change [run_process_replay] --- README.md | 4 +- test/external/process_replay/README.md | 8 ++-- .../external/process_replay/process_replay.py | 42 +++++++++---------- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index ad5021c293..d49804fb58 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,4 @@ python3 -m pytest test/ # whole test suite #### Process replay tests -[Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) detects changes in the generated kernels of CI tests by comparing them against tinygrad master. If your PR is a refactor or speedup without any expected behavior change, it should include a green process replay pass to get merged. - -You can enable process replay by adding [run_process_replay] to your PR title. [example](https://github.com/tinygrad/tinygrad/pull/4995). Note that you should keep your branch up-to-date with master. +[Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) compares your PR's generated kernels against master. If your PR is a refactor or speedup without any expected behavior change, It should include [run_process_replay] in the PR title, [example](https://github.com/tinygrad/tinygrad/pull/4995). Note that you should keep your branch up-to-date with master. diff --git a/test/external/process_replay/README.md b/test/external/process_replay/README.md index d6557adac6..67873a149e 100644 --- a/test/external/process_replay/README.md +++ b/test/external/process_replay/README.md @@ -1,10 +1,10 @@ # Process replay tests -Process replay is a tool for creating a diff of generated kernels between two commits. +Process replay is a tool for creating a diff of generated kernels between two commits. By default, process replay doesn't assert kernel diffs. -Refactor and speedup PRs need a green process replay check. +Refactor and speedup prs must enable the assert by including `[run_process_replay]` in the pr title. -Behavior change PRs can use process replay with `ASSERT_PROCESS_REPLAY=0` to check the diff is what was expected. It's also an indirect test coverage checker. +Note that process replay [early stops when over 20% of kernels change, for speed.](https://github.com/tinygrad/tinygrad/pull/5480). ## Running locally @@ -12,6 +12,6 @@ To run process replay locally: (optional: clear previous process replay runs with `test/external/process_replay/reset.py`) -1. Run tests with `RUN_PROCESS_REPLAY=1` in your branch +1. Run tests with `RUN_PROCESS_REPLAY=1` in your branch. This will capture the kernels. 2. Checkout master 3. Run `test/external/process_replay/process_replay.py` diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 77abffefa1..bcbf3e4d82 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,22 +1,22 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -import difflib, pickle, multiprocessing, os +import difflib, pickle, multiprocessing, os, logging from tinygrad.codegen.kernel import Kernel from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm -page_size = 100 -table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{VERSION}" +PAGE_SIZE = 100 +TABLE_NAME = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{VERSION}" +ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k))) +MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20) +assert MAX_DIFF_PCT < 100 early_stop = multiprocessing.Event() +logging.basicConfig(level=logging.INFO, format='%(message)s') def process_replay(offset:int): if early_stop.is_set(): return - ASSERT_PROCESS_REPLAY = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or \ - k in os.getenv("PR_TITLE", k))) - MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20) - assert MAX_DIFF_PCT <= 100 conn = db_connection() cur = conn.cursor() - cur.execute(f"SELECT val FROM '{table_name}' LIMIT ? OFFSET ?", (page_size, offset)) + cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) changed = 0 for row in cur.fetchall(): ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0]) @@ -27,25 +27,25 @@ def process_replay(offset:int): for opt in applied_opts: k.apply_opt(opt) good_src = k.opts.render(name, k.linearize().uops) except Exception as e: - print("FAILED TO RECREATE KERNEL") - print(ast) - print(applied_opts) - print(e) - if ASSERT_PROCESS_REPLAY: raise e + logging.warn("FAILED TO RECREATE KERNEL") + logging.info(ast) + logging.info(applied_opts) + logging.info(e) + if ASSERT_DIFF: raise e continue # try compare try: assert compare_src == good_src except AssertionError as e: changed += 1 - print("PROCESS REPLAY DETECTED CHANGE") - print(ast) - print(applied_opts) + logging.info("PROCESS REPLAY DETECTED CHANGE") + logging.info(ast) + logging.info(applied_opts) diff = list(difflib.unified_diff(good_src.splitlines(), compare_src.splitlines())) for line in diff: - print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) - if ASSERT_PROCESS_REPLAY: raise e + logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) + if ASSERT_DIFF: raise e if changed > MAX_DIFF_PCT: - print(f"WARN: detected chanegs in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.") + logging.warn(f"detected chanegs in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.") early_stop.set() break conn.commit() @@ -54,8 +54,8 @@ def process_replay(offset:int): if __name__ == "__main__": conn = db_connection() cur = conn.cursor() - row_count = cur.execute(f"select count(*) from '{table_name}'").fetchone()[0] + row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0] conn.commit() cur.close() - offsets = range(0, row_count, page_size) + offsets = range(0, row_count, PAGE_SIZE) with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool: list(tqdm(pool.imap(process_replay, offsets), total=len(offsets)))