mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
even smaller process_replay.py [pr] (#6941)
* even smaller process_replay.py [pr] * delete those tests * dedup asts
This commit is contained in:
83
test/external/process_replay/process_replay.py
vendored
83
test/external/process_replay/process_replay.py
vendored
@@ -1,11 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# compare kernels created by HEAD against master
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools
|
||||
from typing import Callable, List, Tuple, Union, cast
|
||||
from tinygrad.engine.schedule import full_ast_rewrite
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from tinygrad.codegen.kernel import Kernel, Opt
|
||||
from test.external.process_replay.helpers import ProcessReplayContext, print_diff
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# *** process replay settings
|
||||
|
||||
@@ -25,77 +27,50 @@ if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
|
||||
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
|
||||
if REF == "master": SKIP_PROCESS_REPLAY = True
|
||||
|
||||
# *** differs
|
||||
# *** recreators
|
||||
|
||||
def diff_schedule(offset:int) -> bool:
  """Replay one page of captured schedule ASTs and report whether any of them changed.

  Reads up to PAGE_SIZE rows from the 'schedule_{TABLE_NAME}' table, starting at `offset`.
  Each row unpickles to (raw_ast, ctx, compare_ast). raw_ast is pushed through
  full_ast_rewrite again and the result is compared with the captured compare_ast.

  Returns True when at least one row differs. A row that fails to unpickle or fails to
  rewrite returns True immediately if ASSERT_DIFF is set, otherwise it is logged and skipped.
  """
  conn = db_connection()
  cur = conn.cursor()
  cur.execute(f"SELECT val FROM 'schedule_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  changed = 0
  for row in cur.fetchall():
    # try unpickle
    # NOTE: rows are unpickled as-is; only replay a database written by a trusted capture run
    try: raw_ast, ctx, compare_ast = pickle.loads(row[0])
    except Exception as e:
      logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}")
      if ASSERT_DIFF: return True
      continue
    # try full_ast_rewrite
    try: good_ast = full_ast_rewrite(raw_ast, ctx)
    except Exception as e:
      logging.warning(f"FAILED TO DO AST REWRITE {e}")
      logging.info(raw_ast)
      logging.info(ctx)
      if ASSERT_DIFF: return True
      continue
    # diff asts (plain comparison instead of `assert`, which is removed under python -O)
    if compare_ast == good_ast: continue
    # count the mismatch so the return value actually reflects detected changes
    changed += 1
    logging.info("PROCESS REPLAY DETECTED CHANGE")
    logging.info(raw_ast)
    logging.info(ctx)
    print_diff(good_ast, compare_ast)
  return bool(changed)
|
||||
def recreate_sched(sink:UOp, ctx) -> UOp:
  """Recreate a schedule AST by running full_ast_rewrite on the captured sink and ctx."""
  return full_ast_rewrite(sink, ctx)
|
||||
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext) -> str:
  """Rebuild a kernel from its captured inputs and return the rendered source.

  A Kernel is constructed from `ast` with the captured renderer `opts`, every entry of
  `applied_opts` is applied in order, and the resulting program's uops are rendered
  under `name`. The work runs inside a Context built from `ctx.ctx_vars`, restricted to
  keys present in ContextVar._cache and excluding "DEBUG".
  """
  replay_vars = {key: val for key, val in ctx.ctx_vars.items() if key in ContextVar._cache and key != "DEBUG"}
  with Context(**replay_vars):
    kernel = Kernel(ast, opts=opts)
    for applied in applied_opts:
      kernel.apply_opt(applied)
    # NOTE: replay with the captured renderer, not the one in master
    uops = cast(List, kernel.to_program().uops)
    return kernel.opts.render(name, uops)
|
||||
|
||||
def diff_kernel(offset:int) -> Union[Tuple[int, int], bool]:
|
||||
# *** diff a "good" recreation against the generated version
|
||||
|
||||
def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
|
||||
if early_stop.is_set(): return True
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
cur.execute(f"SELECT val FROM 'kernel_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
|
||||
cur.execute(f"SELECT val FROM '{name}_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
|
||||
additions, deletions, changed = 0, 0, 0
|
||||
for row in cur.fetchall():
|
||||
# try unpickle
|
||||
try: ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
|
||||
try: args = pickle.loads(row[0])
|
||||
except Exception as e:
|
||||
logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}")
|
||||
if ASSERT_DIFF: return True
|
||||
continue
|
||||
# try linearize
|
||||
try:
|
||||
with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
|
||||
k = Kernel(ast, opts=opts)
|
||||
for opt in applied_opts: k.apply_opt(opt)
|
||||
# NOTE: replay with the captured renderer, not the one in master
|
||||
good_src = k.opts.render(name, cast(List,k.to_program().uops))
|
||||
# try recreate
|
||||
try: good = fxn(*args[:-1])
|
||||
except Exception as e:
|
||||
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
|
||||
logging.info(ast)
|
||||
logging.info(applied_opts)
|
||||
for x in args[:-1]: logging.info(x)
|
||||
if ASSERT_DIFF: return True
|
||||
continue
|
||||
# diff kernels
|
||||
try: assert compare_src == good_src
|
||||
try: assert args[-1] == good
|
||||
except AssertionError:
|
||||
logging.info("PROCESS REPLAY DETECTED CHANGE")
|
||||
logging.info(ast)
|
||||
logging.info(applied_opts)
|
||||
logging.info(ctx.loc)
|
||||
print_diff(good_src, compare_src)
|
||||
changes = list(difflib.unified_diff(str(good_src).splitlines(), str(compare_src).splitlines()))
|
||||
for x in args[:-1]: logging.info(x)
|
||||
print_diff(good, args[-1])
|
||||
changes = list(difflib.unified_diff(str(good).splitlines(), str(args[-1]).splitlines()))
|
||||
additions += len([x for x in changes if x.startswith("+")])
|
||||
deletions += len([x for x in changes if x.startswith("-")])
|
||||
if ASSERT_DIFF: return additions, deletions
|
||||
if changed > MAX_DIFF_PCT:
|
||||
logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
|
||||
logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of {name}s. skipping further diff generation.")
|
||||
early_stop.set()
|
||||
break
|
||||
conn.commit()
|
||||
@@ -104,7 +79,7 @@ def diff_kernel(offset:int) -> Union[Tuple[int, int], bool]:
|
||||
|
||||
# *** generic runner for executing fxn across all rows of a table in parallel
|
||||
|
||||
def _pmap(name:str, fxn:Callable[[int], Union[bool, Tuple[int, int]]], maxtasksperchild:int=16) -> None:
|
||||
def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
|
||||
@@ -115,7 +90,7 @@ def _pmap(name:str, fxn:Callable[[int], Union[bool, Tuple[int, int]]], maxtasksp
|
||||
cur.close()
|
||||
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
|
||||
inputs = list(range(0, row_count, PAGE_SIZE))
|
||||
ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs)))
|
||||
ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
|
||||
pool.close()
|
||||
pool.join()
|
||||
pool.terminate()
|
||||
@@ -133,7 +108,7 @@ if __name__ == "__main__":
|
||||
logging.info("skipping process replay.")
|
||||
exit(0)
|
||||
|
||||
for name,fxn in [("schedule", diff_schedule), ("kernel", diff_kernel)]:
|
||||
for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
|
||||
logging.info(f"***** {name} diff")
|
||||
try: _pmap(name, fxn)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
import unittest
|
||||
import contextlib, sqlite3
|
||||
from test.external.process_replay.helpers import ProcessReplayContext
|
||||
from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
|
||||
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.helpers import to_function_name, db_connection, diskcache_put, VERSION
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def helper_append_replay(ast:UOp, name:str, src:str) -> int:
  """Append one kernel row to the process replay table and return the table's row count.

  The stored value is (ast, ClangRenderer(), [], to_function_name(name), src,
  ProcessReplayContext({})) under the key "test_1".

  `name` is the kernel name to record. The table name is kept in its own local so it no
  longer overwrites `name` before to_function_name is called.
  """
  table = f"kernel_{TABLE_NAME}"
  # the table name is passed to diskcache_put with the "_{VERSION}" suffix removed
  # (presumably diskcache_put adds the version itself - TODO confirm)
  diskcache_put(table.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, ProcessReplayContext({})))
  conn = db_connection()
  row_count = conn.execute(f"select count(*) from '{table}'").fetchone()[0]
  return row_count
|
||||
|
||||
class TestProcessReplay(unittest.TestCase):
  """Checks diff_kernel against rows appended with helper_append_replay."""

  def tearDown(self):
    # remove every row whose key matches 'test_%' from the kernel replay table
    connection = db_connection()
    cursor = connection.cursor()
    cleanup_sql = f"DELETE FROM 'kernel_{TABLE_NAME}' WHERE key LIKE 'test_%'"
    # an OperationalError raised by the DELETE is deliberately ignored
    with contextlib.suppress(sqlite3.OperationalError):
      cursor.execute(cleanup_sql)
    connection.commit()
    cursor.close()

  def test_simple_diff(self):
    # a hand-written source is stored as the captured kernel; the expected result is (5, 4)
    tensor = Tensor([1, 2, 3])+1
    last_ast = tensor.schedule()[-1].ast
    test_src = """
void test(int* restrict a, const int* restrict b) {
  for (int ridx0 = 0; ridx0 < 3; ridx0++) {
    int val0 = b[ridx0];
    a[ridx0] = (val0+1);
  }
}
"""
    row_count = helper_append_replay(last_ast, "test", test_src)
    result = diff_kernel(row_count-1)
    assert result == (5, 4)

  def test_identical_run(self):
    # the program's own name and source are stored, so no additions or deletions are expected
    tensor = Tensor([1, 2, 3])+1
    last_ast = tensor.schedule()[-1].ast
    program = Kernel(last_ast, ClangRenderer()).to_program()
    row_count = helper_append_replay(last_ast, program.name, program.src)
    # NOTE(review): the row count itself is used as the offset here, while test_simple_diff
    # uses row_count-1 - confirm this offset still selects the appended row
    result = diff_kernel(row_count)
    assert result == (0, 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,12 +0,0 @@
|
||||
#!/bin/bash
# Runs process_replay.py twice with ASSERT_PROCESS_REPLAY=1: once after editing kernel.py
# (a zero exit status is treated as an error) and once after stashing that edit.

# should assert
sed -i 's/temp/temp1/g' ./tinygrad/codegen/kernel.py
if COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null; then
  echo "PROCESS REPLAY IS WRONG."
  exit 1
fi

# should NOT assert
git stash > /dev/null
COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
|
||||
@@ -728,7 +728,7 @@ class Kernel:
|
||||
|
||||
if getenv("RUN_PROCESS_REPLAY"):
|
||||
from test.external.process_replay.helpers import get_process_replay_ctx
|
||||
diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, get_process_replay_ctx()))
|
||||
diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, get_process_replay_ctx(), src))
|
||||
|
||||
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import sys, pickle, atexit, uuid
|
||||
import sys, pickle, atexit
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
@@ -131,7 +131,7 @@ def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
if not AST_REWRITE: return base_sink
|
||||
sink = graph_rewrite(base_sink, reduceop_fusor)
|
||||
ret = graph_rewrite(sink, enumerate_bufs, ctx)
|
||||
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_process_replay", str(uuid.uuid4()), (base_sink, ctx, ret))
|
||||
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))
|
||||
return ret
|
||||
|
||||
# *** List[LazyBuffer] lowering to ScheduleItem ***
|
||||
|
||||
Reference in New Issue
Block a user