mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
This reverts commit eda177da84.
This commit is contained in:
@@ -2,14 +2,14 @@
|
||||
# extract asts from process replay artifacts
|
||||
import os, pickle
|
||||
from tinygrad.helpers import db_connection, getenv, VERSION
|
||||
from test.external.process_replay.process_replay import _run_differ
|
||||
from test.external.process_replay.process_replay import _pmap
|
||||
|
||||
PAGE_SIZE = 100
|
||||
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
|
||||
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
|
||||
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
|
||||
|
||||
def extract_ast(offset:int):
|
||||
def extract_ast(offset:int) -> bool:
|
||||
logops = open(LOGOPS, "a")
|
||||
conn = db_connection()
|
||||
for row in conn.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall():
|
||||
@@ -19,4 +19,4 @@ def extract_ast(offset:int):
|
||||
if __name__ == "__main__":
|
||||
conn = db_connection()
|
||||
row_count = conn.execute(f"SELECT COUNT(*) FROM '{TABLE_NAME}'").fetchone()[0]
|
||||
_run_differ(row_count, extract_ast)
|
||||
_pmap(row_count, extract_ast)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
export LOGOPS=/tmp/sops
|
||||
export LOGOPS=/tmp/ops
|
||||
export RUN_PROCESS_REPLAY=1
|
||||
rm $LOGOPS
|
||||
test/external/process_replay/reset.py
|
||||
|
||||
23
test/external/process_replay/diff_schedule.py
vendored
23
test/external/process_replay/diff_schedule.py
vendored
@@ -40,24 +40,27 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
|
||||
seen_diffs.add(cache_key)
|
||||
changed += 1
|
||||
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [ref.ast.key, compare.ast.key]))
|
||||
if not CI: print_si_diff(si[0], si[1])
|
||||
if not CI: print_si_diff(ref, compare)
|
||||
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
|
||||
return changed
|
||||
|
||||
def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
|
||||
def print_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
print_diff(si0.ast, si1.ast)
|
||||
print_diff(ref.ast, compare.ast)
|
||||
# skip lowering/runtime error
|
||||
with contextlib.suppress(Exception):
|
||||
ei0 = lower_schedule_item(si0)
|
||||
ei1 = lower_schedule_item(si1)
|
||||
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
|
||||
if DEBUG >= 4: print_diff(ei0.prg.p.src, ei1.prg.p.src)
|
||||
with contextlib.suppress(Exception): lower_si_diff(ref, compare)
|
||||
|
||||
def lower_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None:
|
||||
if DEBUG >= 4:
|
||||
ref_ei = lower_schedule_item(ref)
|
||||
compare_ei = lower_schedule_item(compare)
|
||||
assert isinstance(ref_ei.prg, CompiledRunner) and isinstance(compare_ei.prg, CompiledRunner)
|
||||
print_diff(ref_ei.prg.p.src, compare_ei.prg.p.src)
|
||||
# TODO: create new Buffers for process replay to test correctness
|
||||
if getenv("TIMING"):
|
||||
with Context(DEBUG=2):
|
||||
tm0 = ei0.run(wait=True)
|
||||
tm1 = ei1.run(wait=True)
|
||||
tm0 = ref_ei.run(wait=True)
|
||||
tm1 = compare_ei.run(wait=True)
|
||||
assert tm0 is not None and tm1 is not None
|
||||
tm_diff = ((tm0 - tm1) / tm0) * 100
|
||||
if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green"))
|
||||
|
||||
14
test/external/process_replay/process_replay.py
vendored
14
test/external/process_replay/process_replay.py
vendored
@@ -82,17 +82,19 @@ def diff_kernel(offset:int) -> bool:
|
||||
cur.close()
|
||||
return bool(changed)
|
||||
|
||||
# *** differ runners with multiprocessing
|
||||
# *** generic runner for executing fxn across all rows of a table in parallel
|
||||
|
||||
def _run_differ(row_count:int, differ:Callable[[int], bool]) -> None:
|
||||
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
|
||||
def _pmap(row_count:int, fxn:Callable[[int], bool], maxtasksperchild:int=16) -> None:
|
||||
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
|
||||
inputs = list(range(0, row_count, PAGE_SIZE))
|
||||
changed: List[bool] = list(tqdm(pool.imap_unordered(differ, inputs), total=len(inputs)))
|
||||
changed: List[bool] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs)))
|
||||
pool.close()
|
||||
pool.join()
|
||||
pool.terminate()
|
||||
if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
|
||||
|
||||
# *** process replay parallel differ runners
|
||||
|
||||
def process_replay_schedule() -> None:
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
@@ -105,7 +107,7 @@ def process_replay_schedule() -> None:
|
||||
if row_count != 0: logging.info("***** schedule diff")
|
||||
conn.commit()
|
||||
cur.close()
|
||||
_run_differ(row_count, diff_schedule)
|
||||
_pmap(row_count, diff_schedule)
|
||||
|
||||
def process_replay_kernel() -> None:
|
||||
conn = db_connection()
|
||||
@@ -116,7 +118,7 @@ def process_replay_kernel() -> None:
|
||||
return None
|
||||
conn.commit()
|
||||
cur.close()
|
||||
_run_differ(row_count, diff_kernel)
|
||||
_pmap(row_count, diff_kernel)
|
||||
|
||||
# *** main loop
|
||||
|
||||
|
||||
Reference in New Issue
Block a user