diff --git a/extra/optimization/extract_dataset.py b/extra/optimization/extract_dataset.py index 7a4df252e9..174c276e37 100755 --- a/extra/optimization/extract_dataset.py +++ b/extra/optimization/extract_dataset.py @@ -2,14 +2,14 @@ # extract asts from process replay artifacts import os, pickle from tinygrad.helpers import db_connection, getenv, VERSION -from test.external.process_replay.process_replay import _pmap +from test.external.process_replay.process_replay import _run_differ PAGE_SIZE = 100 RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD") TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}" LOGOPS = os.getenv("LOGOPS", "/tmp/ops") -def extract_ast(offset:int) -> bool: +def extract_ast(offset:int): logops = open(LOGOPS, "a") conn = db_connection() for row in conn.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall(): @@ -19,4 +19,4 @@ def extract_ast(offset:int) -> bool: if __name__ == "__main__": conn = db_connection() row_count = conn.execute(f"SELECT COUNT(*) FROM '{TABLE_NAME}'").fetchone()[0] - _pmap(row_count, extract_ast) + _run_differ(row_count, extract_ast) diff --git a/extra/optimization/generate_dataset.sh b/extra/optimization/generate_dataset.sh index 905d7d589b..36f57890f0 100755 --- a/extra/optimization/generate_dataset.sh +++ b/extra/optimization/generate_dataset.sh @@ -1,5 +1,5 @@ #!/bin/bash -export LOGOPS=/tmp/ops +export LOGOPS=/tmp/sops export RUN_PROCESS_REPLAY=1 rm $LOGOPS test/external/process_replay/reset.py diff --git a/test/external/process_replay/diff_schedule.py b/test/external/process_replay/diff_schedule.py index cb2ac33002..25df93342d 100644 --- a/test/external/process_replay/diff_schedule.py +++ b/test/external/process_replay/diff_schedule.py @@ -40,27 +40,24 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]] seen_diffs.add(cache_key) changed += 1 if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [ref.ast.key, compare.ast.key])) - if not CI: print_si_diff(ref, compare) + if not CI: print_si_diff(si[0], si[1]) if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed") return changed -def print_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None: +def print_si_diff(si0:ScheduleItem, si1:ScheduleItem): logging.basicConfig(level=logging.INFO) - print_diff(ref.ast, compare.ast) + print_diff(si0.ast, si1.ast) # skip lowering/runtime error - with contextlib.suppress(Exception): lower_si_diff(ref, compare) - -def lower_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None: - if DEBUG >= 4: - ref_ei = lower_schedule_item(ref) - compare_ei = lower_schedule_item(compare) - assert isinstance(ref_ei.prg, CompiledRunner) and isinstance(compare_ei.prg, CompiledRunner) - print_diff(ref_ei.prg.p.src, compare_ei.prg.p.src) + with contextlib.suppress(Exception): + ei0 = lower_schedule_item(si0) + ei1 = lower_schedule_item(si1) + assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner) + if DEBUG >= 4: print_diff(ei0.prg.p.src, ei1.prg.p.src) # TODO: create new Buffers for process replay to test correctness if getenv("TIMING"): with Context(DEBUG=2): - tm0 = ref_ei.run(wait=True) - tm1 = compare_ei.run(wait=True) + tm0 = ei0.run(wait=True) + tm1 = ei1.run(wait=True) assert tm0 is not None and tm1 is not None tm_diff = ((tm0 - tm1) / tm0) * 100 if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green")) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 9345ee87df..edabbe9b42 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -82,19 +82,17 @@ def diff_kernel(offset:int) -> bool: cur.close() return bool(changed) -# *** generic runner for executing fxn across all rows of a table in parallel +# *** differ runners with multiprocessing -def _pmap(row_count:int, fxn:Callable[[int], bool], maxtasksperchild:int=16) -> None: - with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool: +def _run_differ(row_count:int, differ:Callable[[int], bool]) -> None: + with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool: inputs = list(range(0, row_count, PAGE_SIZE)) - changed: List[bool] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs))) + changed: List[bool] = list(tqdm(pool.imap_unordered(differ, inputs), total=len(inputs))) pool.close() pool.join() pool.terminate() if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes") -# *** process replay parallel differ runners - def process_replay_schedule() -> None: conn = db_connection() cur = conn.cursor() @@ -107,7 +105,7 @@ def process_replay_schedule() -> None: if row_count != 0: logging.info("***** schedule diff") conn.commit() cur.close() - _pmap(row_count, diff_schedule) + _run_differ(row_count, diff_schedule) def process_replay_kernel() -> None: conn = db_connection() @@ -118,7 +116,7 @@ def process_replay_kernel() -> None: return None conn.commit() cur.close() - _pmap(row_count, diff_kernel) + _run_differ(row_count, diff_kernel) # *** main loop