diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 6e7e7370bc..4d9b03c385 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -import difflib, pickle, multiprocessing, os, logging, sqlite3 +import difflib, pickle, multiprocessing, os, logging, sqlite3, requests, io, zipfile from typing import List from tinygrad.codegen.kernel import Kernel +from tinygrad.device import Device from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm from tinygrad.ops import LazyOp @@ -12,7 +13,9 @@ MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20) TABLE_NAME = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{VERSION}" REF_TABLE_NAME = f"process_replay_master_{VERSION}" ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k))) -SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") or REF == "master" +if REF == "master": ASSERT_DIFF = False +COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k))) +SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") early_stop = multiprocessing.Event() logging.basicConfig(level=logging.INFO, format='%(message)s') @@ -39,7 +42,7 @@ def process_replay(offset:int, ref_schedule:List[LazyOp]): if ASSERT_DIFF: raise e continue # try compare - if getenv("COMPARE_SCHEDULE") and ast not in ref_schedule: + if COMPARE_SCHEDULE and ast not in ref_schedule: with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}): print(opts.render(name, Kernel(ast, opts=opts).linearize().uops)) continue @@ -61,25 +64,47 @@ def process_replay(offset:int, ref_schedule:List[LazyOp]): cur.close() def get_ref_schedule(offset:int, ref_schedule): - conn = db_connection() + conn = sqlite3.connect("/tmp/process_replay/process_replay.db") cur = conn.cursor() cur.execute(f"SELECT val FROM '{REF_TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) for row in cur.fetchall(): ref_schedule.append(pickle.loads(row[0])[0]) + conn.commit() + cur.close() if __name__ == "__main__": if SKIP_PROCESS_REPLAY: logging.info("skipping process replay.") exit(0) - conn = db_connection() - cur = conn.cursor() ref_schedule = multiprocessing.Manager().list() - if getenv("COMPARE_SCHEDULE"): - row_count = cur.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0] + # *** download the reference schedule + if COMPARE_SCHEDULE: + logging.info("fetching process replay reference") + # TODO: make this run_id dynamic + run_id = "10093148840" + name = f"process_replay_{Device.DEFAULT.lower()}.db" # TODO: onnx and openpilot is matrix.task + url = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}/actions/runs/{run_id}/artifacts?name={name}" + headers = { + "Authorization": f"Bearer {os.getenv('GITHUB_TOKEN')}", + "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28" + } + res = requests.get(url, headers=headers) + assert res.status_code == 200, f"download failed {res.status_code} {res.json()}" + download_url = res.json()["artifacts"][0]["archive_download_url"] + res = requests.get(download_url, headers=headers) + assert res.status_code == 200, f"download failed {res.status_code}" + with io.BytesIO(res.content) as zip_content: + with zipfile.ZipFile(zip_content, "r") as zip_ref: zip_ref.extractall("/tmp/process_replay/") + ref_conn = sqlite3.connect("/tmp/process_replay/process_replay.db") + row_count = ref_conn.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0] processes = [] for i in tqdm(range(0, row_count, PAGE_SIZE)): processes.append(p:=multiprocessing.Process(target=get_ref_schedule, args=(i, ref_schedule))) p.start() for p in processes: p.join() + ref_conn.close() + conn = db_connection() + cur = conn.cursor() + # *** run the comparison try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0] except sqlite3.OperationalError: logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")