mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
more work toward non-blocking process replay (#5653)
* non-blocking process replay * more actionable * test it * revert the test * %s/logging.warn/logging.warning
This commit is contained in:
17
test/external/process_replay/process_replay.py
vendored
17
test/external/process_replay/process_replay.py
vendored
@@ -1,15 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
# compare kernels created by HEAD against master
|
||||
import difflib, pickle, multiprocessing, os, logging
|
||||
import difflib, pickle, multiprocessing, os, logging, sqlite3
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
|
||||
|
||||
PAGE_SIZE = 100
|
||||
REF = os.getenv("GITHUB_REF_NAME", "")
|
||||
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
|
||||
TABLE_NAME = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{VERSION}"
|
||||
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
|
||||
REF = os.getenv("GITHUB_REF_NAME", "")
|
||||
SKIP_PROCESS_REPLAY = int((k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")) or REF == "master"
|
||||
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
|
||||
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") or REF == "master"
|
||||
early_stop = multiprocessing.Event()
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
|
||||
@@ -29,7 +29,7 @@ def process_replay(offset:int):
|
||||
for opt in applied_opts: k.apply_opt(opt)
|
||||
good_src = k.opts.render(name, k.linearize().uops)
|
||||
except Exception as e:
|
||||
logging.warn("FAILED TO RECREATE KERNEL")
|
||||
logging.warning("FAILED TO RECREATE KERNEL")
|
||||
logging.info(ast)
|
||||
logging.info(applied_opts)
|
||||
logging.info(e)
|
||||
@@ -47,7 +47,7 @@ def process_replay(offset:int):
|
||||
logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
|
||||
if ASSERT_DIFF: raise e
|
||||
if changed > MAX_DIFF_PCT:
|
||||
logging.warn(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
|
||||
logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
|
||||
early_stop.set()
|
||||
break
|
||||
conn.commit()
|
||||
@@ -59,7 +59,10 @@ if __name__ == "__main__":
|
||||
exit(0)
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
||||
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
||||
except sqlite3.OperationalError:
|
||||
logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
|
||||
exit(0)
|
||||
conn.commit()
|
||||
cur.close()
|
||||
offsets = range(0, row_count, PAGE_SIZE)
|
||||
|
||||
Reference in New Issue
Block a user