more work toward non-blocking process replay (#5653)

* non-blocking process replay

* more actionable

* test it

* revert the test

* %s/logging.warn/logging.warning
This commit is contained in:
qazal
2024-07-23 19:26:31 +08:00
committed by GitHub
parent a93982ef42
commit 5f394fc9c6

View File

@@ -1,15 +1,15 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import difflib, pickle, multiprocessing, os, logging
import difflib, pickle, multiprocessing, os, logging, sqlite3
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
PAGE_SIZE = 100
REF = os.getenv("GITHUB_REF_NAME", "")
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
TABLE_NAME = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{VERSION}"
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
REF = os.getenv("GITHUB_REF_NAME", "")
SKIP_PROCESS_REPLAY = int((k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")) or REF == "master"
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") or REF == "master"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format='%(message)s')
@@ -29,7 +29,7 @@ def process_replay(offset:int):
for opt in applied_opts: k.apply_opt(opt)
good_src = k.opts.render(name, k.linearize().uops)
except Exception as e:
logging.warn("FAILED TO RECREATE KERNEL")
logging.warning("FAILED TO RECREATE KERNEL")
logging.info(ast)
logging.info(applied_opts)
logging.info(e)
@@ -47,7 +47,7 @@ def process_replay(offset:int):
logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
if ASSERT_DIFF: raise e
if changed > MAX_DIFF_PCT:
logging.warn(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
early_stop.set()
break
conn.commit()
@@ -59,7 +59,10 @@ if __name__ == "__main__":
exit(0)
conn = db_connection()
cur = conn.cursor()
row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
except sqlite3.OperationalError:
logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
exit(0)
conn.commit()
cur.close()
offsets = range(0, row_count, PAGE_SIZE)