even smaller process_replay.py [pr] (#6941)

* even smaller process_replay.py [pr]

* delete those tests

* dedup asts
qazal
2024-10-08 15:43:22 +03:00
committed by GitHub
parent 851f39653a
commit 2800520dd5
5 changed files with 32 additions and 118 deletions

test/external/process_replay/process_replay.py View File

@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-import os, multiprocessing, logging, pickle, sqlite3, difflib
+import os, multiprocessing, logging, pickle, sqlite3, difflib, functools
 from typing import Callable, List, Tuple, Union, cast
 from tinygrad.engine.schedule import full_ast_rewrite
 from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
-from tinygrad.codegen.kernel import Kernel
-from test.external.process_replay.helpers import print_diff
+from tinygrad.codegen.kernel import Kernel, Opt
+from test.external.process_replay.helpers import ProcessReplayContext, print_diff
 from tinygrad.ops import UOp
 from tinygrad.renderer import Renderer
 # *** process replay settings
@@ -25,77 +27,50 @@ if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
 SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
 if REF == "master": SKIP_PROCESS_REPLAY = True
-# *** differs
-def diff_schedule(offset:int) -> bool:
-  conn = db_connection()
-  cur = conn.cursor()
-  cur.execute(f"SELECT val FROM 'schedule_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
-  changed = 0
-  for row in cur.fetchall():
-    # try unpickle
-    try: raw_ast, ctx, compare_ast = pickle.loads(row[0])
-    except Exception as e:
-      logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}")
-      if ASSERT_DIFF: return True
-      continue
-    # try full_ast_rewrite
-    try: good_ast = full_ast_rewrite(raw_ast, ctx)
-    except Exception as e:
-      logging.warning(f"FAILED TO DO AST REWRITE {e}")
-      logging.info(raw_ast)
-      logging.info(ctx)
-      if ASSERT_DIFF: return True
-      continue
-    # diff asts
-    try: assert compare_ast == good_ast
-    except AssertionError:
-      logging.info("PROCESS REPLAY DETECTED CHANGE")
-      logging.info(raw_ast)
-      logging.info(ctx)
-      print_diff(good_ast, compare_ast)
-  return bool(changed)
+# *** recreators
+def recreate_sched(sink:UOp, ctx) -> UOp: return full_ast_rewrite(sink, ctx)
+def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext) -> str:
+  with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
+    k = Kernel(ast, opts=opts)
+    for opt in applied_opts: k.apply_opt(opt)
+    # NOTE: replay with the captured renderer, not the one in master
+    return k.opts.render(name, cast(List,k.to_program().uops))
-def diff_kernel(offset:int) -> Union[Tuple[int, int], bool]:
+# *** diff a "good" recreation against the generated version
+def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
   if early_stop.is_set(): return True
   conn = db_connection()
   cur = conn.cursor()
-  cur.execute(f"SELECT val FROM 'kernel_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
+  cur.execute(f"SELECT val FROM '{name}_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
   additions, deletions, changed = 0, 0, 0
   for row in cur.fetchall():
     # try unpickle
-    try: ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
+    try: args = pickle.loads(row[0])
     except Exception as e:
       logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}")
       if ASSERT_DIFF: return True
      continue
-    # try linearize
-    try:
-      with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
-        k = Kernel(ast, opts=opts)
-        for opt in applied_opts: k.apply_opt(opt)
-        # NOTE: replay with the captured renderer, not the one in master
-        good_src = k.opts.render(name, cast(List,k.to_program().uops))
+    # try recreate
+    try: good = fxn(*args[:-1])
     except Exception as e:
       logging.warning(f"FAILED TO RECREATE KERNEL {e}")
-      logging.info(ast)
-      logging.info(applied_opts)
+      for x in args[:-1]: logging.info(x)
       if ASSERT_DIFF: return True
       continue
     # diff kernels
-    try: assert compare_src == good_src
+    try: assert args[-1] == good
     except AssertionError:
       logging.info("PROCESS REPLAY DETECTED CHANGE")
-      logging.info(ast)
-      logging.info(applied_opts)
-      logging.info(ctx.loc)
-      print_diff(good_src, compare_src)
-      changes = list(difflib.unified_diff(str(good_src).splitlines(), str(compare_src).splitlines()))
+      for x in args[:-1]: logging.info(x)
+      print_diff(good, args[-1])
+      changes = list(difflib.unified_diff(str(good).splitlines(), str(args[-1]).splitlines()))
       additions += len([x for x in changes if x.startswith("+")])
       deletions += len([x for x in changes if x.startswith("-")])
       if ASSERT_DIFF: return additions, deletions
     if changed > MAX_DIFF_PCT:
-      logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
+      logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of {name}s. skipping further diff generation.")
       early_stop.set()
       break
   conn.commit()
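
With the two differs collapsed into one generic diff(), the contract is carried entirely by the pickled row layout: every capture ends with the generated artifact, and everything before it is the argument list of the matching recreator, so fxn(*args[:-1]) rebuilds the artifact and args[-1] is the captured version to compare against. A minimal runnable sketch of that convention (toy names, not tinygrad code):

    import pickle

    def recreate_upper(s:str) -> str: return s.upper()  # stand-in for recreate_sched/recreate_kernel

    # each row pickles (*recreator_args, captured_output); the second capture drifted
    rows = [pickle.dumps(("hello", "HELLO")), pickle.dumps(("world", "WORLD!"))]
    for raw in rows:
      args = pickle.loads(raw)
      good = recreate_upper(*args[:-1])  # recreate from everything but the last element
      if args[-1] != good: print(f"DETECTED CHANGE: {good!r} != {args[-1]!r}")
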
@@ -104,7 +79,7 @@ def diff_kernel(offset:int) -> Union[Tuple[int, int], bool]:
 # *** generic runner for executing fxn across all rows of a table in parallel
-def _pmap(name:str, fxn:Callable[[int], Union[bool, Tuple[int, int]]], maxtasksperchild:int=16) -> None:
+def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
   conn = db_connection()
   cur = conn.cursor()
   try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
@@ -115,7 +90,7 @@ def _pmap(name:str, fxn:Callable[[int], Union[bool, Tuple[int, int]]], maxtasksp
   cur.close()
   with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
     inputs = list(range(0, row_count, PAGE_SIZE))
-    ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs)))
+    ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
     pool.close()
     pool.join()
     pool.terminate()
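
Since _pmap now receives a recreator rather than a full differ, functools.partial bakes the table name and the recreator into the shared diff() before handing it to the pool; each worker only ever receives an integer page offset. A runnable sketch of the pattern (toy stand-ins for diff and the recreator, not tinygrad code):

    import functools, multiprocessing
    from typing import Callable

    def recreate(offset:int) -> bool: return False  # stand-in recreator check

    def diff(offset:int, name:str, fxn:Callable) -> bool:
      return fxn(offset)  # stand-in: replay one page of rows starting at this offset

    if __name__ == "__main__":
      PAGE_SIZE, row_count = 100, 1000
      inputs = list(range(0, row_count, PAGE_SIZE))  # one task per page of rows
      with multiprocessing.get_context("spawn").Pool(4) as pool:
        # partial() pre-binds the keyword args; only the offset crosses the process boundary
        ret = list(pool.imap_unordered(functools.partial(diff, name="kernel", fxn=recreate), inputs))
      print(ret)  # [False, False, ...]; a True would mean a detected diff
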
@@ -133,7 +108,7 @@ if __name__ == "__main__":
     logging.info("skipping process replay.")
     exit(0)
-  for name,fxn in [("schedule", diff_schedule), ("kernel", diff_kernel)]:
+  for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
     logging.info(f"***** {name} diff")
     try: _pmap(name, fxn)
     except Exception as e:
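
One detail worth calling out from recreate_kernel above: the captured ctx_vars are restored before rebuilding the kernel, but only for variables that still exist on the current branch, and never DEBUG, so replay output stays quiet. A sketch of that filter, assuming the real Context/ContextVar from tinygrad.helpers (values illustrative):

    from tinygrad.helpers import Context, ContextVar

    ctx_vars = {"BEAM": 2, "DEBUG": 4, "REMOVED_FLAG": 1}  # as captured on the PR branch
    # restore only vars this branch still knows about, and skip DEBUG
    with Context(**{k:v for k,v in ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
      pass  # Kernel recreation would run here with the captured settings
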

View File

@@ -1,49 +0,0 @@
-import unittest
-import contextlib, sqlite3
-from test.external.process_replay.helpers import ProcessReplayContext
-from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
-from tinygrad.codegen.kernel import Kernel
-from tinygrad.helpers import to_function_name, db_connection, diskcache_put, VERSION
-from tinygrad.ops import UOp
-from tinygrad.renderer.cstyle import ClangRenderer
-from tinygrad.tensor import Tensor
-
-def helper_append_replay(ast:UOp, name:str, src:str) -> int:
-  name = f"kernel_{TABLE_NAME}"
-  diskcache_put(name.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, ProcessReplayContext({})))
-  conn = db_connection()
-  row_count = conn.execute(f"select count(*) from '{name}'").fetchone()[0]
-  return row_count
-
-class TestProcessReplay(unittest.TestCase):
-  def tearDown(self):
-    conn = db_connection()
-    cur = conn.cursor()
-    with contextlib.suppress(sqlite3.OperationalError): cur.execute(f"DELETE FROM 'kernel_{TABLE_NAME}' WHERE key LIKE 'test_%'")
-    conn.commit()
-    cur.close()
-
-  def test_simple_diff(self):
-    out = Tensor([1, 2, 3])+1
-    ast = out.schedule()[-1].ast
-    test_src = """
-void test(int* restrict a, const int* restrict b) {
-  for (int ridx0 = 0; ridx0 < 3; ridx0++) {
-    int val0 = b[ridx0];
-    a[ridx0] = (val0+1);
-  }
-}
-"""
-    offset = helper_append_replay(ast, "test", test_src)
-    assert diff_kernel(offset-1) == (5, 4)
-
-  def test_identical_run(self):
-    out = Tensor([1, 2, 3])+1
-    ast = out.schedule()[-1].ast
-    test_prg = Kernel(ast, ClangRenderer()).to_program()
-    offset = helper_append_replay(ast, test_prg.name, test_prg.src)
-    assert diff_kernel(offset) == (0, 0)
-
-if __name__ == "__main__":
-  unittest.main()

View File

@@ -1,12 +0,0 @@
-#!/bin/bash
-# should assert
-sed -i 's/temp/temp1/g' ./tinygrad/codegen/kernel.py
-COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
-if [[ $? -eq 0 ]]; then
-  echo "PROCESS REPLAY IS WRONG."
-  exit 1
-fi
-
-# should NOT assert
-git stash > /dev/null
-COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null

tinygrad/codegen/kernel.py View File

@@ -728,7 +728,7 @@ class Kernel:
     if getenv("RUN_PROCESS_REPLAY"):
       from test.external.process_replay.helpers import get_process_replay_ctx
-      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, get_process_replay_ctx()))
+      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, get_process_replay_ctx(), src))
     # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
     # TODO: these max and min don't work on symbolic, and results are very wrong.
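
The reorder here is what makes the generic differ work: the captured source moves to the last slot so args[-1] is always the compare target, while args[:-1] lines up positionally with recreate_kernel(ast, opts, applied_opts, name, ctx). A runnable toy of that invariant (illustrative values, not tinygrad code):

    def recreate_kernel(ast, opts, applied_opts, name, ctx): return f"{name}:{ast}"  # stand-in

    # new capture order: (..., ctx, src), so the generated src is last
    capture = ("ADD", "clang", [], "E_3", {}, "E_3:ADD")
    assert recreate_kernel(*capture[:-1]) == capture[-1]  # args[:-1] matches the signature, args[-1] is the target
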

tinygrad/engine/schedule.py View File

@@ -1,4 +1,4 @@
-import sys, pickle, atexit, uuid
+import sys, pickle, atexit
 from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
@@ -131,7 +131,7 @@ def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
   if not AST_REWRITE: return base_sink
   sink = graph_rewrite(base_sink, reduceop_fusor)
   ret = graph_rewrite(sink, enumerate_bufs, ctx)
-  if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_process_replay", str(uuid.uuid4()), (base_sink, ctx, ret))
+  if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))
   return ret
 # *** List[LazyBuffer] lowering to ScheduleItem ***
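
This last change is the "dedup asts" bullet from the commit message: keying the capture by base_sink.key, which the bullet implies is content-derived, instead of a fresh uuid4() means identical ASTs collapse into a single diskcache row. A toy sketch of the effect, with a dict standing in for the sqlite-backed cache (not tinygrad code):

    import uuid, hashlib

    cache = {}
    asts = ["SINK(ADD)", "SINK(ADD)", "SINK(MUL)"]  # two identical sinks captured twice
    for a in asts: cache[str(uuid.uuid4())] = a     # old: random key, one row per capture
    print(len(cache))                               # 3
    cache.clear()
    for a in asts: cache[hashlib.sha256(a.encode()).hexdigest()] = a  # new: content key
    print(len(cache))                               # 2 -- duplicates overwrite the same row
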