separate process replay main loop (#5734)

* separate process replay main loop

* [run_process_replay]

* add kernel_changed

* test with [run_process_replay]

* revert temp [run_process_replay]
This commit is contained in:
qazal
2024-07-27 02:43:08 +08:00
committed by GitHub
parent 9838c1a6ff
commit 94d578396f

View File

@@ -19,7 +19,7 @@ SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format='%(message)s')
def process_replay(offset:int, ref_schedule:List[LazyOp]):
def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed):
if early_stop.is_set(): return
conn = db_connection()
cur = conn.cursor()
@@ -39,6 +39,7 @@ def process_replay(offset:int, ref_schedule:List[LazyOp]):
logging.info(ast)
logging.info(applied_opts)
logging.info(e)
kernel_changed.value = True
if ASSERT_DIFF: raise e
continue
# try compare
@@ -55,6 +56,7 @@ def process_replay(offset:int, ref_schedule:List[LazyOp]):
diff = list(difflib.unified_diff(good_src.splitlines(), compare_src.splitlines()))
for line in diff:
logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
kernel_changed.value = True
if ASSERT_DIFF: raise e
if changed > MAX_DIFF_PCT:
logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
@@ -71,10 +73,7 @@ def get_ref_schedule(offset:int, ref_schedule):
conn.commit()
cur.close()
if __name__ == "__main__":
if SKIP_PROCESS_REPLAY:
logging.info("skipping process replay.")
exit(0)
def process_replay():
ref_schedule = multiprocessing.Manager().list()
# *** download the reference schedule
if COMPARE_SCHEDULE:
@@ -112,7 +111,17 @@ if __name__ == "__main__":
conn.commit()
cur.close()
processes = []
changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=process_replay, args=(i, ref_schedule)))
processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, ref_schedule, changed)))
p.start()
for p in processes: p.join()
if changed.value and ASSERT_DIFF: raise Exception("process replay detected changes")
if __name__ == "__main__":
if SKIP_PROCESS_REPLAY:
logging.info("skipping process replay.")
exit(0)
try: process_replay()
except Exception as e:
if ASSERT_DIFF: raise e