process replay filter warnings [pr] (#8199)

This commit is contained in:
qazal
2024-12-13 11:43:43 +02:00
committed by GitHub
parent c5c0d0277d
commit 5864627abe

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, List, Set, Tuple, Union, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import ScheduleContext, full_ast_rewrite
@@ -25,6 +25,7 @@ ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in
if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
class ProcessReplayWarning(Warning): pass
# *** recreators
@@ -56,9 +57,8 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
if good is None: continue
except Exception as e:
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
warnings.warn(f"FAILED TO RECREATE KERNEL {e}", ProcessReplayWarning)
for x in args[:-1]: logging.info(x)
if ASSERT_DIFF: return True
continue
# diff kernels
try: assert args[-1] == good
@@ -85,7 +85,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
except sqlite3.OperationalError:
logging.warning(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
warnings.warn(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
return None
conn.commit()
cur.close()
@@ -100,7 +100,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
logging.info(f"{sum(changed)} kernels changed")
if sum(insertion) != 0: logging.info(colored(f"{sum(insertion)} insertions(+)", "green"))
if sum(deletions) != 0: logging.info(colored(f"{sum(deletions)} deletions(-)", "red"))
if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
if any(changed): warnings.warn("process replay detected changes", ProcessReplayWarning)
# *** main loop
@@ -109,6 +109,7 @@ if __name__ == "__main__":
logging.info("skipping process replay.")
exit(0)
if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
logging.info(f"***** {name} diff")
try: _pmap(name, fxn)