mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-17 18:11:49 -05:00
* process replay rewrite p2 * start some unittests + exceptions and exits * shebang * remove extra kernel init
47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
import unittest
|
|
from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
|
|
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.helpers import to_function_name, db_connection, diskcache_put, VERSION
|
|
from tinygrad.ops import UOp
|
|
from tinygrad.renderer.cstyle import ClangRenderer
|
|
from tinygrad.tensor import Tensor
|
|
|
|
def helper_append_replay(ast:UOp, name:str, src:str) -> int:
  """Write one process-replay entry under the key "test_1" and return the replay table's row count.

  The stored value is the tuple (ast, ClangRenderer(), [], to_function_name(name), src, {}); the tests
  use `src` as the source that diff_kernel compares against (a hand-written mismatch in
  test_simple_diff, the real rendered source in test_identical_run).

  Returns:
    `select count(*)` over TABLE_NAME after the write. The tests treat `count - 1` as the offset of the
    entry just written (assumes the new row is the last one in scan order — TODO confirm).
  """
  suffix = f"_{VERSION}"
  # The write goes to TABLE_NAME without its version suffix while the count reads TABLE_NAME itself,
  # so diskcache_put presumably re-appends f"_{VERSION}" (verify in tinygrad.helpers).
  # Strip only a *trailing* suffix: str.replace would also delete the substring if it occurred earlier
  # in the name, sending the entry to a different table than the one counted below.
  base_table = TABLE_NAME[:-len(suffix)] if TABLE_NAME.endswith(suffix) else TABLE_NAME
  diskcache_put(base_table, "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, {}))
  # NOTE(review): the connection is not closed here; tearDown reuses db_connection() as well, so it is
  # presumably a shared cached connection — confirm before adding a close().
  conn = db_connection()
  row_count = conn.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  return row_count
|
|
|
|
class TestProcessReplay(unittest.TestCase):
  """Checks diff_kernel: a mismatching stored source is reported, the real rendered source is not."""

  def tearDown(self):
    """Delete the rows the tests wrote (helper_append_replay stores them under key "test_1")."""
    conn = db_connection()
    # GLOB instead of LIKE: in LIKE, "_" is a single-character wildcard, so 'test_%' would also match
    # keys such as "tests..." or "testing...". In GLOB only "*", "?" and "[...]" are special.
    # conn.execute avoids a manually managed cursor that would leak if the statement raised.
    conn.execute(f"DELETE FROM '{TABLE_NAME}' WHERE key GLOB 'test_*'")
    conn.commit()

  def test_simple_diff(self):
    out = Tensor([1, 2, 3])+1
    ast = out.schedule()[-1].ast
    # Hand-written source for this ast that is not expected to match what ClangRenderer produces.
    test_src = """
void test(int* restrict a, const int* restrict b) {
  for (int ridx0 = 0; ridx0 < 3; ridx0++) {
    int val0 = b[ridx0];
    a[ridx0] = (val0+1);
  }
}
    """
    offset = helper_append_replay(ast, "test", test_src)
    # offset is the row count after the write; offset-1 addresses the entry just appended.
    self.assertTrue(diff_kernel(offset-1))

  def test_identical_run(self):
    out = Tensor([1, 2, 3])+1
    ast = out.schedule()[-1].ast
    test_prg = Kernel(ast, ClangRenderer()).to_program()
    offset = helper_append_replay(ast, test_prg.name, test_prg.src)
    # Use offset-1 like test_simple_diff. Passing the raw row count starts one past the last row, so
    # (assuming diff_kernel takes a 0-based row offset) no rows would be checked and the test would
    # pass vacuously.
    self.assertFalse(diff_kernel(offset-1))
|
|
|
|
if __name__ == "__main__":
  # Executed as a script (not imported by a test runner): discover and run this module's TestCases.
  unittest.main()
|