cleanup process_replay/* namings [run_process_replay] (#6429)

This commit is contained in:
qazal
2024-09-09 16:59:04 +08:00
committed by GitHub
parent 8186e4e7d6
commit f4e83b30b4
4 changed files with 25 additions and 20 deletions

View File

@@ -40,24 +40,27 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
seen_diffs.add(cache_key)
changed += 1
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [ref.ast.key, compare.ast.key]))
if not CI: print_si_diff(si[0], si[1])
if not CI: print_si_diff(ref, compare)
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
return changed
def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
def print_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None:
logging.basicConfig(level=logging.INFO)
print_diff(si0.ast, si1.ast)
print_diff(ref.ast, compare.ast)
# skip lowering/runtime error
with contextlib.suppress(Exception):
ei0 = lower_schedule_item(si0)
ei1 = lower_schedule_item(si1)
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
if DEBUG >= 4: print_diff(ei0.prg.p.src, ei1.prg.p.src)
with contextlib.suppress(Exception): lower_si_diff(ref, compare)
def lower_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None:
if DEBUG >= 4:
ref_ei = lower_schedule_item(ref)
compare_ei = lower_schedule_item(compare)
assert isinstance(ref_ei.prg, CompiledRunner) and isinstance(compare_ei.prg, CompiledRunner)
print_diff(ref_ei.prg.p.src, compare_ei.prg.p.src)
# TODO: create new Buffers for process replay to test correctness
if getenv("TIMING"):
with Context(DEBUG=2):
tm0 = ei0.run(wait=True)
tm1 = ei1.run(wait=True)
tm0 = ref_ei.run(wait=True)
tm1 = compare_ei.run(wait=True)
assert tm0 is not None and tm1 is not None
tm_diff = ((tm0 - tm1) / tm0) * 100
if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green"))

View File

@@ -82,17 +82,19 @@ def diff_kernel(offset:int) -> bool:
cur.close()
return bool(changed)
# *** differ runners with multiprocessing
# *** generic runner for executing fxn across all rows of a table in parallel
def _run_differ(row_count:int, differ:Callable[[int], bool]) -> None:
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
def _pmap(row_count:int, fxn:Callable[[int], bool], maxtasksperchild:int=16) -> None:
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
inputs = list(range(0, row_count, PAGE_SIZE))
changed: List[bool] = list(tqdm(pool.imap_unordered(differ, inputs), total=len(inputs)))
changed: List[bool] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs)))
pool.close()
pool.join()
pool.terminate()
if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
# *** process replay parallel differ runners
def process_replay_schedule() -> None:
conn = db_connection()
cur = conn.cursor()
@@ -105,7 +107,7 @@ def process_replay_schedule() -> None:
if row_count != 0: logging.info("***** schedule diff")
conn.commit()
cur.close()
_run_differ(row_count, diff_schedule)
_pmap(row_count, diff_schedule)
def process_replay_kernel() -> None:
conn = db_connection()
@@ -116,7 +118,7 @@ def process_replay_kernel() -> None:
return None
conn.commit()
cur.close()
_run_differ(row_count, diff_kernel)
_pmap(row_count, diff_kernel)
# *** main loop