mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
process replay filter warnings [pr] (#8199)
This commit is contained in:
11
test/external/process_replay/process_replay.py
vendored
11
test/external/process_replay/process_replay.py
vendored
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# compare kernels created by HEAD against master
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
|
||||
from typing import Callable, List, Set, Tuple, Union, cast
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.engine.schedule import ScheduleContext, full_ast_rewrite
|
||||
@@ -25,6 +25,7 @@ ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in
|
||||
if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
|
||||
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
|
||||
if REF == "master": SKIP_PROCESS_REPLAY = True
|
||||
class ProcessReplayWarning(Warning): pass
|
||||
|
||||
# *** recreators
|
||||
|
||||
@@ -56,9 +57,8 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
|
||||
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
|
||||
if good is None: continue
|
||||
except Exception as e:
|
||||
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
|
||||
warnings.warn(f"FAILED TO RECREATE KERNEL {e}", ProcessReplayWarning)
|
||||
for x in args[:-1]: logging.info(x)
|
||||
if ASSERT_DIFF: return True
|
||||
continue
|
||||
# diff kernels
|
||||
try: assert args[-1] == good
|
||||
@@ -85,7 +85,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
|
||||
cur = conn.cursor()
|
||||
try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
|
||||
except sqlite3.OperationalError:
|
||||
logging.warning(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
|
||||
warnings.warn(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
|
||||
return None
|
||||
conn.commit()
|
||||
cur.close()
|
||||
@@ -100,7 +100,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
|
||||
logging.info(f"{sum(changed)} kernels changed")
|
||||
if sum(insertion) != 0: logging.info(colored(f"{sum(insertion)} insertions(+)", "green"))
|
||||
if sum(deletions) != 0: logging.info(colored(f"{sum(deletions)} deletions(-)", "red"))
|
||||
if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
|
||||
if any(changed): warnings.warn("process replay detected changes", ProcessReplayWarning)
|
||||
|
||||
# *** main loop
|
||||
|
||||
@@ -109,6 +109,7 @@ if __name__ == "__main__":
|
||||
logging.info("skipping process replay.")
|
||||
exit(0)
|
||||
|
||||
if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
|
||||
for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
|
||||
logging.info(f"***** {name} diff")
|
||||
try: _pmap(name, fxn)
|
||||
|
||||
Reference in New Issue
Block a user