From 335a261a2ef95499263f1fc9acb290f30e01897b Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 25 Jan 2023 10:25:22 -0800
Subject: [PATCH] test for slow kernel

---
 test/external_test_gpu_ast.py | 9 +++++++++
 tinygrad/llops/ops_gpu.py     | 4 ++--
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/test/external_test_gpu_ast.py b/test/external_test_gpu_ast.py
index 3880632b1e..634dc0f30b 100644
--- a/test/external_test_gpu_ast.py
+++ b/test/external_test_gpu_ast.py
@@ -100,5 +100,14 @@ class TestAST(unittest.TestCase):
     ast = LazyOp(MovementOps.RESHAPE, (op1,), (1, 1, 1, 1, 4956))
     compile_and_test_ast(ast)
 
+  def test_enet_reduce_bs32(self):
+    buf0 = GPUBuffer(shape=ShapeTracker(shape=(3, 1, 32, 3, 3, 32, 112, 112), views=[View((3, 32, 225, 225), (50176, 150528, 224, 1), 0), ZeroView((3, 32, 224, 224), ((0, 3), (0, 32), (0, 225), (0, 225))), View((3, 1, 32, 3, 3, 32, 112, 112), (1620000, 1620000, 0, 225, 1, 50625, 450, 2), 0)]), hostbuf=GPUBuffer(shape=(32, 3, 224, 224), force_create=True))
+    buf1 = GPUBuffer(shape=ShapeTracker(shape=(3, 1, 32, 3, 3, 32, 112, 112), views=[View((3, 1, 32, 3, 3, 32, 112, 112), (0, 12845056, 401408, 0, 0, 12544, 112, 1), 0)]), hostbuf=GPUBuffer(shape=(1, 1, 32, 1, 1, 32, 112, 112), force_create=True))
+    op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
+    op1 = LazyOp(ReduceOps.SUM, (op0,), (3, 1, 32, 3, 3, 1, 1, 1))
+    ast = LazyOp(MovementOps.RESHAPE, (op1,), (3, 32, 3, 3))
+    compile_and_test_ast(ast)
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 6685ca238d..dc9ae6ae21 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -17,7 +17,7 @@ NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", "0")) # this is needed as a swit
 
 CLCACHE = int(os.getenv("CLCACHE", "1"))
 FLOAT16 = int(os.getenv("FLOAT16", "0"))
-PRINT_AST = int(os.getenv("PRINT_AST", "0"))
+PRINT_AST = os.getenv("PRINT_AST", "0")
 TEST_AST = int(os.getenv("TEST_AST", "0"))
 
 class CLBuffer:
@@ -425,7 +425,7 @@ class GPUBuffer(ExplicitExecAST):
   def exec_ast(cls, ast:LazyOp):
     k = CLASTKernel(ast)
     k.codegen()(*k.bufs)
-    if PRINT_AST:
+    if PRINT_AST == "1" or PRINT_AST == k.fxn.name:
       print(k.fxn.name)
       k.print()
     if TEST_AST: