Files
tinygrad/test/external/process_replay/process_replay.py
chenyu 6b3112d525 fix qcom process_replay for kernel diff (#6079)
* debug why qcom process_replay does not run

skipping the wrong exception?

* um-hum

* get_step_times was parsed incorrectly

* cleanup
2024-08-14 15:05:49 -04:00

152 lines
7.0 KiB
Python
Executable File

#!/usr/bin/env python3
# compare kernels created by HEAD against master
import difflib, pickle, multiprocessing, os, logging, sqlite3, requests
from tabulate import tabulate
from datetime import datetime
from typing import Dict, List, cast
from test.external.process_replay.utils import print_diff
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
# *** process replay settings
# rows of the capture table handled per worker task (used as the SQL LIMIT and as the offset stride)
PAGE_SIZE = 100
# CI branch/ref name; "" when GITHUB_REF_NAME is unset
REF = os.getenv("GITHUB_REF_NAME", "")
# diff_kernel stops diffing once a page has more than this many changed kernels (a percentage of the page when PAGE_SIZE == 100)
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
# "HEAD" when GITHUB_RUN_ID is unset
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
# table the captured kernels are read from, scoped to run id, run attempt and VERSION
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
# assert on diffs when "[run_process_replay]" is in the commit message or PR title.
# NOTE: the os.getenv fallback is the tag itself, so if COMMIT_MESSAGE or PR_TITLE is unset this is truthy
# (presumably so local runs assert by default) -- confirm intent.
# NOTE: the walrus binds a module-level `k`; it is rebound by the next two lines.
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
# log captured schedule diffs; off unless "[compare_schedule]" is in the commit message or PR title
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")))
# opt-out tag in the commit message or PR title
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
# never replay on master itself
if REF == "master": SKIP_PROCESS_REPLAY = True
# set by diff_kernel after too many changes so later diff_kernel calls return early.
# NOTE(review): process_replay runs diff_kernel in a "spawn" pool, which re-imports this module in each worker,
# so each worker presumably gets its own Event -- confirm that setting it actually stops the other workers.
early_stop = multiprocessing.Event()
# log only the message text (no level/logger prefix)
logging.basicConfig(level=logging.INFO, format='%(message)s')
# *** github settings
# GitHub REST API base for the repository (GITHUB_REPOSITORY, default tinygrad/tinygrad)
BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}"
# bearer-token auth from GH_TOKEN plus the API media-type and version headers
GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
def diff_kernel(offset:int) -> bool:
  """Replay one page of captured kernels and diff the rebuilt source against the captured source.

  Reads up to PAGE_SIZE rows starting at `offset` from TABLE_NAME. Each row is a pickled
  (ast, opts, applied_opts, name, compare_src, ctx) tuple. The kernel is rebuilt under the captured
  ContextVars with the captured renderer, and the re-rendered source is compared with `compare_src`.

  Returns True if early_stop was already set, if any kernel in the page differed, or if a kernel
  could not be rebuilt while ASSERT_DIFF is set; otherwise False.

  NOTE(review): early_stop is a module-level Event and process_replay runs this in a "spawn" pool,
  so each worker presumably has its own copy; setting it may only stop later pages in the same worker.
  """
  if early_stop.is_set(): return True
  conn = db_connection()
  cur = conn.cursor()
  changed = 0
  try:
    cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
    for row in cur.fetchall():
      ast, applied_opts = None, None
      # try unpickle and linearize
      try:
        # NOTE: pickle.loads can execute arbitrary code; the table is assumed to be written only by our own capture step
        ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
        # replay under the captured ContextVars, except DEBUG (presumably so log verbosity follows the current environment)
        with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
          k = Kernel(ast, opts=opts)
          for opt in applied_opts: k.apply_opt(opt)
          # NOTE: replay with the captured renderer, not the one in master
          good_src = k.opts.render(name, cast(List,k.to_program().uops))
      except Exception as e:
        logging.warning("FAILED TO RECREATE KERNEL")
        logging.info(ast)
        logging.info(applied_opts)
        logging.info(e)
        if ASSERT_DIFF: return True
        continue
      # explicit comparison rather than `assert`, so the check still runs under `python -O`
      if compare_src == good_src: continue
      changed += 1
      logging.info("PROCESS REPLAY DETECTED CHANGE")
      logging.info(ast)
      logging.info(applied_opts)
      for line in difflib.unified_diff(good_src.splitlines(), compare_src.splitlines()):
        logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
      if ASSERT_DIFF: return True
      # percentage of a PAGE_SIZE page; identical to `changed > MAX_DIFF_PCT` when PAGE_SIZE == 100
      if changed * 100 > MAX_DIFF_PCT * PAGE_SIZE:
        logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
        early_stop.set()
        break
    conn.commit()
  finally:
    # close the cursor on every path, including the early `return True`s
    cur.close()
  return bool(changed)
def print_ast_diff(offset:int):
  """Log one page (PAGE_SIZE rows from `offset`) of the schedule_diff_{VERSION} table.

  Each row is a pickled (buf, asts) tuple. With a single AST the buffer is reported as folded and the
  AST is logged. Otherwise the first two ASTs are passed to print_diff.
  """
  conn = db_connection()
  cur = conn.cursor()
  try:
    cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
    for row in cur.fetchall():
      buf, asts = pickle.loads(row[0])
      if len(asts) == 1:
        logging.info(f"{buf} was folded")
        logging.info(asts[0])
      else: print_diff(asts[0], asts[1])
  finally:
    # the cursor was previously never closed
    cur.close()
def get_step_times(data) -> Dict[str, float]:
  """Return {step name: duration in seconds} for a GitHub Actions job dict.

  Only data["steps"] is read; each step needs "name", "started_at" and "completed_at"
  (ISO-8601 "...Z" strings). The first 4 steps are skipped (presumably job setup -- confirm). Collection
  stops at the "Run actions/upload-artifact@v4" step, which is treated as the end of the timed steps.
  Steps whose timestamps are missing or null are skipped rather than crashing strptime. The GitHub API
  documents both fields as nullable, e.g. for unfinished steps.
  """
  fmt = "%Y-%m-%dT%H:%M:%SZ"
  tms: Dict[str, float] = {}
  for step in data["steps"][4:]:
    # last task
    if step["name"] == "Run actions/upload-artifact@v4": break
    started, completed = step.get("started_at"), step.get("completed_at")
    if started is None or completed is None: continue
    tms[step["name"]] = (datetime.strptime(completed, fmt) - datetime.strptime(started, fmt)).total_seconds()
  return tms
def _run_paged(fn, row_count:int) -> list:
  # map fn over page offsets 0, PAGE_SIZE, 2*PAGE_SIZE, ... < row_count in a spawn-context pool, with a progress bar
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))
    ret = list(tqdm(pool.imap_unordered(fn, inputs), total=len(inputs)))
    pool.close()
    pool.join()
  # leaving the `with` block terminates the (already joined) pool
  return ret

def _speed_diff():
  # log per-step timings of this run's benchmark job next to the latest successful master benchmark.yml run
  name = {"testmacbenchmark": "Mac", "testnvidiabenchmark": "tinybox green", "testmorenvidiabenchmark": "tinybox green Training",
          "testamdbenchmark": "tinybox red", "testmoreamdbenchmark": "tinybox red Training",
          "testqualcommbenchmark": "comma"}[os.environ["GITHUB_JOB"]]
  compare_jobs = requests.get(f"{BASE_URL}/actions/runs/{RUN_ID}/jobs", headers=GH_HEADERS).json()["jobs"]
  compare_job = next(j for j in compare_jobs if j["name"] == f"{name} Benchmark")
  ref_runs = requests.get(f"{BASE_URL}/actions/workflows/benchmark.yml/runs?per_page=1&branch=master&status=success", headers=GH_HEADERS).json()
  # pass GH_HEADERS here too so this call is authenticated like the others
  ref_jobs = requests.get(f"{BASE_URL}/actions/runs/{ref_runs['workflow_runs'][0]['id']}/jobs", headers=GH_HEADERS).json()["jobs"]
  ref_job = next(j for j in ref_jobs if j["name"] == f"{name} Benchmark")
  logging.info(f"comparing speed for {compare_job['id']} against {ref_job['id']}")
  compare_tms = get_step_times(compare_job)
  ref_tms = get_step_times(ref_job)
  # compare only steps timed in both jobs; skip zero-length master steps so we never divide by 0
  diff = [[k, f"{v}s", f"{compare_tms[k]}s", f"{(((v-compare_tms[k])/v)*100):7.2f}%"] for k,v in ref_tms.items() if v>0 and k in compare_tms]
  logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))

def _schedule_diff():
  # log every captured schedule diff; raise when ASSERT_DIFF is set and the diff table exists
  conn = db_connection()
  cur = conn.cursor()
  try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone()
  except sqlite3.OperationalError:
    logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
    exit(0)
  if not has_diff:
    cur.close()
    return
  row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
  conn.commit()
  cur.close()
  _run_paged(print_ast_diff, row_count)
  if ASSERT_DIFF: raise Exception("schedule process replay detected changes")

def _kernel_diff():
  # replay every captured kernel in TABLE_NAME; raise when ASSERT_DIFF is set and any page reported a change
  conn = db_connection()
  cur = conn.cursor()
  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  except sqlite3.OperationalError:
    logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
    exit(0)
  conn.commit()
  cur.close()
  changed = _run_paged(diff_kernel, row_count)
  if any(changed) and ASSERT_DIFF: raise Exception("kernel process replay detected changes")

def process_replay():
  """Run the process-replay checks for this CI job.

  1. speed diff: on the update_benchmark ref (except the qualcomm job), log step timings against master.
  2. schedule diff: when COMPARE_SCHEDULE is set, log the captured schedule diffs.
  3. kernel diff: rebuild every captured kernel and diff its source against the captured source.

  Raises Exception when ASSERT_DIFF is set and a diff is detected. Calls exit(0) when a capture
  table isn't accessible.
  """
  # *** speed diff (for benchmarks)
  # TODO: fix this for testqualcommbenchmark
  if REF == "update_benchmark" and os.environ["GITHUB_JOB"] != "testqualcommbenchmark": _speed_diff()
  # *** schedule diff
  if COMPARE_SCHEDULE: _schedule_diff()
  # *** kernel diff
  _kernel_diff()
if __name__ == "__main__":
  if SKIP_PROCESS_REPLAY:
    logging.info("skipping process replay.")
    exit(0)
  try: process_replay()
  except Exception:
    # TODO: catch specific Exception
    # with ASSERT_DIFF, propagate (bare raise keeps the original traceback) so the script exits non-zero
    if ASSERT_DIFF: raise
    # otherwise replay is best-effort, but log why it failed instead of swallowing it silently
    logging.warning("process replay failed", exc_info=True)