diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 06a7ed2b1..cd6900321 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -59,7 +59,7 @@ def kernel(C, A, B, M, N, K, tl.store(c_ptrs, c) """ -test_utils_src = ''' +test_utils_src = """ #include #include #include @@ -93,23 +93,26 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) { index++; } fclose(file); -}''' +}""" def gen_kernel_library(dir, libname): c_files = glob.glob(os.path.join(dir, "*.c")) - subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(), - "-c", "-fPIC"], - check=True, cwd=dir) + subprocess.run( + ["gcc"] + c_files + ["-I", cuda_include_dir(), "-c", "-fPIC"], + check=True, + cwd=dir, + ) o_files = glob.glob(os.path.join(dir, "*.o")) - subprocess.run(["gcc"] + o_files + ["-shared", - "-o", libname, - "-L", libcuda_dirs()[0]], - check=True, cwd=dir) + subprocess.run( + ["gcc"] + o_files + ["-shared", "-o", libname, "-L", libcuda_dirs()[0]], + check=True, + cwd=dir, + ) def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): - test_src = f''' + test_src = f""" int main(int argc, char **argv) {{ int M = {M}, N = {N}, K = {K}; @@ -165,17 +168,30 @@ int main(int argc, char **argv) {{ cuMemFree(C); cuCtxDestroy(ctx); }} -''' +""" src = test_utils_src + test_src with open(os.path.join(dir, "test.c"), "w") as file: file.write(src) - subprocess.run(["gcc"] + ["test.c", - "-I", cuda_include_dir(), - "-L", libcuda_dirs()[0], - "-l", "cuda", - "-L", dir, - "-l", "kernel", - "-o", exe], check=True, cwd=dir) + subprocess.run( + ["gcc"] + + [ + "test.c", + "-I", + cuda_include_dir(), + "-L", + libcuda_dirs()[0], + "-l", + "cuda", + "-L", + dir, + "-l", + "kernel", + "-o", + exe, + ], + check=True, + cwd=dir, + ) def write_triton_kernels(dir, src, util_src): @@ -190,16 +206,69 @@ def write_triton_kernels(dir, src, util_src): return kernel_path -def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints): +def _compile_kernel( + dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path +): compiler_path = os.path.join(triton.tools.__path__[0], "compile.py") + subprocess.run( + [ + sys.executable, + compiler_path, + "-n", + kernel_name, + "--signature", + signature, + "--out-name", + out_name, + "-o", + out_path, + "-w", + str(num_warps), + "-g", + grid, + kernel_path, + ], + check=True, + cwd=dir, + ) + + +# Edge case kernel with no specialization +def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK): + # compile all desired configs + sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}" + name = f"matmul_{dtype}" + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) + + +def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints): # compile all desired configs for ha in ha_hb_hints: for hb in ha_hb_hints: - sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}' + sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}" name = f"matmul_{dtype}" - grid = f'M/{BM}, N/{BN}, 1' - subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir) + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) def link_aot_kernels(dir): @@ -207,7 +276,9 @@ def link_aot_kernels(dir): # link all desired configs h_files = glob.glob(os.path.join(dir, "*.h")) - subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir) + subprocess.run( + [sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir + ) def generate_matmul_test_data(dir, M, N, K): @@ -221,16 +292,16 @@ def generate_matmul_test_data(dir, M, N, K): return a, b, a_path, b_path, c_path -def test_compile_link_matmul(): +# Test edge case where the provided kernel signature has no specializations +def test_compile_link_matmul_no_specialization(): np.random.seed(3) with tempfile.TemporaryDirectory() as tmp_dir: - dtype = "fp16" BM, BN, BK = 16, 16, 16 kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) - compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK) link_aot_kernels(tmp_dir) # compile test case @@ -244,13 +315,50 @@ def test_compile_link_matmul(): # run test case env = os.environ.copy() env["LD_LIBRARY_PATH"] = tmp_dir - subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) + subprocess.run( + ["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir + ) # read data and compare against reference c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) c_tri = c.reshape((M, N)).view(np.float32) c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) - np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) + + +def test_compile_link_matmul(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernels( + tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"] + ) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run( + ["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir + ) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) def test_launcher_has_no_available_kernel(): @@ -275,7 +383,13 @@ def test_launcher_has_no_available_kernel(): # run test case env = os.environ.copy() env["LD_LIBRARY_PATH"] = tmp_dir - result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True) + result = subprocess.run( + ["./test", a_path, b_path, c_path], + env=env, + cwd=tmp_dir, + capture_output=True, + text=True, + ) # It should fail since the launcher requires all the strides be 1 while they are not. assert result.returncode == -6 @@ -286,7 +400,6 @@ def test_compile_link_autotune_matmul(): np.random.seed(3) with tempfile.TemporaryDirectory() as tmp_dir: - dtype = "fp16" kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) @@ -300,7 +413,9 @@ def test_compile_link_autotune_matmul(): for ts in tile_sizes: BM, BN, BK = ts[0], ts[1], ts[2] - compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + compile_aot_kernels( + tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"] + ) link_aot_kernels(tmp_dir) @@ -319,7 +434,12 @@ def test_compile_link_autotune_matmul(): env = os.environ.copy() env["LD_LIBRARY_PATH"] = tmp_dir - subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env) + subprocess.run( + [f"./{test_name}", a_path, b_path, c_path], + check=True, + cwd=tmp_dir, + env=env, + ) # read data and compare against reference c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) diff --git a/python/triton/tools/link.py b/python/triton/tools/link.py index 836c89c5f..68ace442f 100644 --- a/python/triton/tools/link.py +++ b/python/triton/tools/link.py @@ -31,7 +31,9 @@ class HeaderParser: import re # [kernel_name, c signature] - self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + self.linker_directives = re.compile( + "//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)" + ) # [name, hash, suffix] self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") # [(type, name)] @@ -42,7 +44,6 @@ class HeaderParser: self.kernels = defaultdict(list) def extract_linker_meta(self, header: str): - for ln in header.splitlines(): if ln.startswith("//"): m = self.linker_directives.match(ln) @@ -76,7 +77,7 @@ class HeaderParser: m = self.c_sig.findall(c_sig) if len(m): tys, args = [], [] - for (ty, arg_name) in m: + for ty, arg_name in m: tys.append(ty) args.append(arg_name) return tys, args @@ -84,7 +85,7 @@ class HeaderParser: raise LinkerError(f"{c_sig} is not a valid argument signature") def _match_suffix(self, suffix: str, c_sig: str): - args = c_sig.split(',') + args = c_sig.split(",") s2i = {"c": 1, "d": 16} num_specs = 0 sizes = [] @@ -110,7 +111,7 @@ class HeaderParser: if name in self.kernels: last: KernelLinkerMeta = self.kernels[name][-1] - for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes): + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): if cur != new_: raise LinkerError( f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" @@ -152,7 +153,9 @@ void unload_{meta.orig_kernel_name}(); # generate dispatcher function for kernels with different meta-parameter and constant values def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" - src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n" + src += ( + f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n" + ) src += "}\n" return src @@ -164,12 +167,28 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" src += "\n" - src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{" + src += ( + f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{" + ) src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None - conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None]) - src += f" if ({conds})\n" + cond_fn = ( + lambda val, hint: f"({val} % {hint} == 0)" + if hint == 16 + else f"({val} == {hint})" + if hint == 1 + else None + ) + conds = " && ".join( + [ + cond_fn(val, hint) + for val, hint in zip(meta.arg_names, meta.sizes) + if hint is not None + ] + ) + src += ( + f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" src += "\n" @@ -183,7 +202,9 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - src += f"void {mode}_{name}() {{" src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += ( + f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + ) src += "}\n" return src @@ -252,7 +273,12 @@ if __name__ == "__main__": help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", ) parser.add_argument("--out", "-o", type=Path, help="Out filename") - parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) args = parser.parse_args() # metadata @@ -280,7 +306,10 @@ if __name__ == "__main__": fp.write(out) # generate source - defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + defs = [ + make_kernel_hints_dispatcher(name, meta) + for name, meta in parser.kernels.items() + ] names = [name for name in parser.kernels.keys()] func_pointers_def = make_func_pointers(names, meta) meta_const_def = make_kernel_meta_const_dispatcher(meta)