Files
tinygrad/test/external/process_replay/test_process_replay.py
qazal dd4e5f1c8d process replay rewrite (#6284)
* process replay rewrite

p2

* start some unittests + exceptions and exits

* shebang

* remove extra kernel init
2024-08-29 15:08:27 +03:00

47 lines
1.5 KiB
Python

import unittest
from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import to_function_name, db_connection, diskcache_put, VERSION
from tinygrad.ops import UOp
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.tensor import Tensor
def helper_append_replay(ast:UOp, name:str, src:str) -> int:
  """Write one replay entry under the key "test_1" and return the table's row count.

  The stored value is the tuple
  (ast, ClangRenderer(), [], to_function_name(name), src, {}).

  Args:
    ast: the kernel AST to store.
    name: kernel name; it is passed through to_function_name before storing.
    src: the source text stored alongside the AST.

  Returns:
    The result of "select count(*)" on TABLE_NAME after the entry was written.
  """
  # diskcache_put receives TABLE_NAME with the "_{VERSION}" suffix removed, while the
  # count query below uses the full TABLE_NAME.
  # NOTE(review): presumably diskcache_put re-appends the version itself - confirm.
  unversioned_table = TABLE_NAME.replace(f"_{VERSION}", "")
  entry = (ast, ClangRenderer(), [], to_function_name(name), src, {})
  diskcache_put(unversioned_table, "test_1", entry)
  count_row = db_connection().execute(f"select count(*) from '{TABLE_NAME}'").fetchone()
  return count_row[0]
class TestProcessReplay(unittest.TestCase):
  """Checks diff_kernel against entries written by helper_append_replay."""

  def tearDown(self):
    """Delete the rows whose key matches the LIKE pattern 'test_%' from TABLE_NAME."""
    conn = db_connection()
    cur = conn.cursor()
    try:
      cur.execute(f"DELETE FROM '{TABLE_NAME}' WHERE key LIKE 'test_%'")
      conn.commit()
    finally:
      # release the cursor even when the DELETE or the commit raises
      cur.close()

  def test_simple_diff(self):
    # Store a hand-written source for the (Tensor + 1) kernel and expect diff_kernel
    # to return a truthy value for it.
    out = Tensor([1, 2, 3])+1
    ast = out.schedule()[-1].ast
    test_src = """
void test(int* restrict a, const int* restrict b) {
for (int ridx0 = 0; ridx0 < 3; ridx0++) {
int val0 = b[ridx0];
a[ridx0] = (val0+1);
}
}
"""
    offset = helper_append_replay(ast, "test", test_src)
    # unittest assertion instead of a bare assert, so the check survives `python -O`
    self.assertTrue(diff_kernel(offset-1))

  def test_identical_run(self):
    # Store the name and source produced by Kernel(ast, ClangRenderer()).to_program()
    # and expect diff_kernel to return a falsy value.
    out = Tensor([1, 2, 3])+1
    ast = out.schedule()[-1].ast
    test_prg = Kernel(ast, ClangRenderer()).to_program()
    offset = helper_append_replay(ast, test_prg.name, test_prg.src)
    # NOTE(review): this passes `offset` (the row count) while test_simple_diff passes
    # `offset-1`. If diff_kernel treats its argument as a row offset, this call may not
    # look at the row that was just written - confirm against diff_kernel.
    self.assertFalse(diff_kernel(offset))
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__": unittest.main()