Revert "Revert "cleanup process_replay/* namings [run_process_replay] (#6429)…" (#6442)

This reverts commit eda177da84.
George Hotz authored on 2024-09-10 07:00:16 +08:00, committed by GitHub
parent 8d3450ceab
commit 904f6a63fa
4 changed files with 25 additions and 20 deletions

View File

@@ -2,14 +2,14 @@
# extract asts from process replay artifacts
import os, pickle
from tinygrad.helpers import db_connection, getenv, VERSION
-from test.external.process_replay.process_replay import _run_differ
+from test.external.process_replay.process_replay import _pmap
PAGE_SIZE = 100
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
-def extract_ast(offset:int):
+def extract_ast(offset:int) -> bool:
logops = open(LOGOPS, "a")
conn = db_connection()
for row in conn.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall():
@@ -19,4 +19,4 @@ def extract_ast(offset:int):
if __name__ == "__main__":
conn = db_connection()
row_count = conn.execute(f"SELECT COUNT(*) FROM '{TABLE_NAME}'").fetchone()[0]
-_run_differ(row_count, extract_ast)
+_pmap(row_count, extract_ast)
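For orientation, the pattern above is: extract_ast handles exactly one page of PAGE_SIZE rows starting at the offset it is given, and _pmap (shown further down) hands one offset per page to a pool of workers. Below is a minimal, self-contained sketch of that paging scheme; the database path, table name and seed data are invented for illustration, and plain sqlite3/multiprocessing stand in for tinygrad's db_connection() and _pmap().

# paging sketch: one worker call per page of PAGE_SIZE rows (illustrative only)
import os, sqlite3, multiprocessing

DB_PATH = "/tmp/example_replay.db"   # hypothetical database, not tinygrad's cache db
PAGE_SIZE = 100                      # same paging constant as extract_ast.py

def extract_page(offset:int) -> bool:
  # every worker opens its own connection and reads exactly one page of rows
  conn = sqlite3.connect(DB_PATH)
  rows = conn.execute("SELECT val FROM kernels LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall()
  conn.close()
  return len(rows) > 0   # the real differ would instead return whether anything changed

if __name__ == "__main__":
  # seed a throwaway table so the sketch runs end to end
  conn = sqlite3.connect(DB_PATH)
  conn.execute("CREATE TABLE IF NOT EXISTS kernels (val TEXT)")
  conn.executemany("INSERT INTO kernels VALUES (?)", [(f"ast_{i}",) for i in range(250)])
  conn.commit()
  row_count = conn.execute("SELECT COUNT(*) FROM kernels").fetchone()[0]
  conn.close()
  offsets = list(range(0, row_count, PAGE_SIZE))   # one offset per page, as _pmap does
  with multiprocessing.get_context("spawn").Pool(os.cpu_count()) as pool:
    print(list(pool.map(extract_page, offsets)))   # e.g. [True, True, True]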

View File

@@ -1,5 +1,5 @@
#!/bin/bash
-export LOGOPS=/tmp/sops
+export LOGOPS=/tmp/ops
export RUN_PROCESS_REPLAY=1
rm $LOGOPS
test/external/process_replay/reset.py
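For reference, these variables are read back on the Python side with plain environment lookups; the LOGOPS default below is taken verbatim from extract_ast.py above, while treating RUN_PROCESS_REPLAY as a simple on/off capture switch is an assumption made for this sketch.

# sketch of the consumer side of the two exports (RUN_PROCESS_REPLAY gating is assumed)
import os

LOGOPS = os.getenv("LOGOPS", "/tmp/sops")               # file the captured ops get appended to
CAPTURE = os.getenv("RUN_PROCESS_REPLAY", "0") == "1"   # assumed on/off switch for capture

if CAPTURE:
  with open(LOGOPS, "a") as logops:
    logops.write("serialized ast would be appended here\n")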

View File

@@ -40,24 +40,27 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
seen_diffs.add(cache_key)
changed += 1
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [ref.ast.key, compare.ast.key]))
-if not CI: print_si_diff(si[0], si[1])
+if not CI: print_si_diff(ref, compare)
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
return changed
-def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
+def print_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None:
logging.basicConfig(level=logging.INFO)
-print_diff(si0.ast, si1.ast)
+print_diff(ref.ast, compare.ast)
# skip lowering/runtime error
-with contextlib.suppress(Exception):
-ei0 = lower_schedule_item(si0)
-ei1 = lower_schedule_item(si1)
-assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
-if DEBUG >= 4: print_diff(ei0.prg.p.src, ei1.prg.p.src)
+with contextlib.suppress(Exception): lower_si_diff(ref, compare)
+def lower_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None:
+if DEBUG >= 4:
+ref_ei = lower_schedule_item(ref)
+compare_ei = lower_schedule_item(compare)
+assert isinstance(ref_ei.prg, CompiledRunner) and isinstance(compare_ei.prg, CompiledRunner)
+print_diff(ref_ei.prg.p.src, compare_ei.prg.p.src)
# TODO: create new Buffers for process replay to test correctness
if getenv("TIMING"):
with Context(DEBUG=2):
-tm0 = ei0.run(wait=True)
-tm1 = ei1.run(wait=True)
+tm0 = ref_ei.run(wait=True)
+tm1 = compare_ei.run(wait=True)
assert tm0 is not None and tm1 is not None
tm_diff = ((tm0 - tm1) / tm0) * 100
if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green"))
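The TIMING branch reports the change as a percentage of the reference kernel's run time, so a positive value means the compare kernel is faster. A quick worked example of that formula, with made-up timings:

# worked example of the tm_diff formula above (timings are invented)
def timing_delta(tm0:float, tm1:float) -> float:
  return ((tm0 - tm1) / tm0) * 100   # same expression as in the diff

print(f"{timing_delta(2.0e-3, 1.5e-3):.2f}% faster")   # 25.00% faster
print(f"{timing_delta(1.5e-3, 2.0e-3):.2f}%")          # -33.33%, i.e. the compare kernel is slower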

View File

@@ -82,17 +82,19 @@ def diff_kernel(offset:int) -> bool:
cur.close()
return bool(changed)
-# *** differ runners with multiprocessing
+# *** generic runner for executing fxn across all rows of a table in parallel
-def _run_differ(row_count:int, differ:Callable[[int], bool]) -> None:
-with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
+def _pmap(row_count:int, fxn:Callable[[int], bool], maxtasksperchild:int=16) -> None:
+with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
inputs = list(range(0, row_count, PAGE_SIZE))
-changed: List[bool] = list(tqdm(pool.imap_unordered(differ, inputs), total=len(inputs)))
+changed: List[bool] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs)))
pool.close()
pool.join()
pool.terminate()
if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
+# *** process replay parallel differ runners
def process_replay_schedule() -> None:
conn = db_connection()
cur = conn.cursor()
@@ -105,7 +107,7 @@ def process_replay_schedule() -> None:
if row_count != 0: logging.info("***** schedule diff")
conn.commit()
cur.close()
-_run_differ(row_count, diff_schedule)
+_pmap(row_count, diff_schedule)
def process_replay_kernel() -> None:
conn = db_connection()
@@ -116,7 +118,7 @@ def process_replay_kernel() -> None:
return None
conn.commit()
cur.close()
-_run_differ(row_count, diff_kernel)
+_pmap(row_count, diff_kernel)
# *** main loop
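To close the loop, here is a stripped-down, runnable sketch of the _pmap pattern the runners above share: fan a page-offset worker out over a spawn pool, collect the per-page "changed" flags, and fail loudly when anything changed. The worker, row count and ASSERT_DIFF default are placeholders, and the tqdm progress wrapper from the real code is dropped to keep the sketch dependency-free.

# self-contained sketch of the _pmap fan-out (placeholder worker, no tqdm progress bar)
import multiprocessing
from typing import Callable, List

PAGE_SIZE = 100
ASSERT_DIFF = False      # in the real runner this comes from an env flag

def fake_differ(offset:int) -> bool:
  # pretend the page starting at this offset contained no changed kernels
  return False

def pmap(row_count:int, fxn:Callable[[int], bool], maxtasksperchild:int=16) -> None:
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))        # one offset per page
    changed: List[bool] = list(pool.imap_unordered(fxn, inputs))
    pool.close()
    pool.join()
  if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")

if __name__ == "__main__":
  pmap(450, fake_differ)   # 450 hypothetical rows -> offsets 0, 100, 200, 300, 400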