mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TOOLS] Add support for autotuning AOT kernel (#2123)
This PR makes the following change to AOT kernel
- Allow the client to generate AOT kernels with different sets of
constexprs and meta-parameters. Each combination of constexpr set and
meta-parameters is referred to an "algo". Within an algo client can
still give different hints about integer arguments.
- Add a API int ${kernle_name}_get_num_algos() that returns the total
number of algos.
- Add a algo_id to allow client to the generated kernel to select the
algo
- Remove gX, gY and gZ from the kernel parameter list. This is because
the launch grid is usually different with different algos, and the
client should not need to care about how to compute the launch grid for
each algo. Instead, we ask the client to pass the expression of
computing gX, gY and gZ for compile.py (when AOT kernels are generated).
The expression can only use kernel parameter or const values.
- We also change the testing flow. Now we first build the kernels into a
shared library libkernel.so, then the client test.c code is built and
link with libkernel.so. This is closer to a typical AOT kernel usage
flow.
This commit is contained in:
@@ -23,26 +23,43 @@ import triton.language as tl
|
||||
import kernel_utils
|
||||
|
||||
@triton.jit
|
||||
def kernel(C, A, B,
|
||||
def kernel(C, A, B, M, N, K,
|
||||
stride_cm, stride_cn,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr):
|
||||
ms = tl.arange(0, BLOCK_M)
|
||||
ns = tl.arange(0, BLOCK_N)
|
||||
ks = tl.arange(0, BLOCK_K)
|
||||
a = tl.load(A + ms[:, None] * stride_am + ks[None, :] * stride_ak)
|
||||
b = tl.load(B + ks[:, None] * stride_bk + ns[None, :] * stride_bn)
|
||||
c = tl.dot(a, b)
|
||||
c = kernel_utils.mul(c, c)
|
||||
tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c)
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
|
||||
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_K)
|
||||
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the K dimension.
|
||||
# If it is out of bounds, set it to 0.
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_K * stride_ak
|
||||
b_ptrs += BLOCK_K * stride_bk
|
||||
|
||||
c = kernel_utils.mul(accumulator, accumulator)
|
||||
# Write back the block of the output matrix C with masks.
|
||||
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
tl.store(c_ptrs, c)
|
||||
"""
|
||||
|
||||
|
||||
def gen_test_bin(dir, M, N, K, BM, BN, BK):
|
||||
test_src = '''
|
||||
test_utils_src = '''
|
||||
#include <cuda.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
@@ -78,10 +95,23 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
|
||||
fclose(file);
|
||||
}'''
|
||||
|
||||
test_src += f'''
|
||||
|
||||
def gen_kernel_library(dir, libname):
|
||||
c_files = glob.glob(os.path.join(dir, "*.c"))
|
||||
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
|
||||
"-c", "-fPIC"],
|
||||
check=True, cwd=dir)
|
||||
o_files = glob.glob(os.path.join(dir, "*.o"))
|
||||
subprocess.run(["gcc"] + o_files + ["-shared",
|
||||
"-o", libname,
|
||||
"-L", libcuda_dirs()[0]],
|
||||
check=True, cwd=dir)
|
||||
|
||||
|
||||
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
|
||||
test_src = f'''
|
||||
int main(int argc, char **argv) {{
|
||||
int M = {M}, N = {N}, K = {K};
|
||||
int BM = {M}, BN = {N}, BK = {K};
|
||||
|
||||
// initialize CUDA handles
|
||||
CUdevice dev;
|
||||
@@ -96,7 +126,7 @@ int main(int argc, char **argv) {{
|
||||
cuMemAlloc(&B, K * N * 2);
|
||||
cuMemAlloc(&C, M * N * 4);
|
||||
cuStreamCreate(&stream, 0);
|
||||
load_matmul_fp16xfp16_16x16x16();
|
||||
load_matmul_fp16();
|
||||
|
||||
// initialize input data
|
||||
int16_t hA[M*K];
|
||||
@@ -110,7 +140,13 @@ int main(int argc, char **argv) {{
|
||||
|
||||
// launch kernel
|
||||
cuStreamSynchronize(stream);
|
||||
CUresult ret = matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1);
|
||||
CUresult ret;
|
||||
int algo_id = {algo_id};
|
||||
if (algo_id == 0) {{
|
||||
ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1);
|
||||
}} else {{
|
||||
ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id});
|
||||
}}
|
||||
if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
|
||||
assert(ret == 0);
|
||||
|
||||
@@ -123,41 +159,51 @@ int main(int argc, char **argv) {{
|
||||
write_buffer_to_csv(argv[3], hC, M*N);
|
||||
|
||||
// free cuda handles
|
||||
unload_matmul_fp16xfp16_16x16x16();
|
||||
unload_matmul_fp16();
|
||||
cuMemFree(A);
|
||||
cuMemFree(B);
|
||||
cuMemFree(C);
|
||||
cuCtxDestroy(ctx);
|
||||
}}
|
||||
'''
|
||||
|
||||
src = test_utils_src + test_src
|
||||
with open(os.path.join(dir, "test.c"), "w") as file:
|
||||
file.write(test_src)
|
||||
c_files = glob.glob(os.path.join(dir, "*.c"))
|
||||
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
|
||||
"-L", libcuda_dirs()[0],
|
||||
"-l", "cuda",
|
||||
"-o", "test"], check=True, cwd=dir)
|
||||
file.write(src)
|
||||
subprocess.run(["gcc"] + ["test.c",
|
||||
"-I", cuda_include_dir(),
|
||||
"-L", libcuda_dirs()[0],
|
||||
"-l", "cuda",
|
||||
"-L", dir,
|
||||
"-l", "kernel",
|
||||
"-o", exe], check=True, cwd=dir)
|
||||
|
||||
|
||||
def generate_matmul_launcher(dir, dtype, BM, BN, BK, ha_hb_hints):
|
||||
def write_triton_kernels(dir, src, util_src):
|
||||
kernel_path = os.path.join(dir, "kernel.py")
|
||||
with open(kernel_path, "w") as file:
|
||||
file.write(kernel_src)
|
||||
file.write(src)
|
||||
|
||||
kernel_utils_path = os.path.join(dir, "kernel_utils.py")
|
||||
with open(kernel_utils_path, "w") as file:
|
||||
file.write(kernel_utils_src)
|
||||
file.write(util_src)
|
||||
|
||||
return kernel_path
|
||||
|
||||
|
||||
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
|
||||
compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
|
||||
linker_path = os.path.join(triton.tools.__path__[0], "link.py")
|
||||
|
||||
# compile all desired configs
|
||||
for ha in ha_hb_hints:
|
||||
for hb in ha_hb_hints:
|
||||
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
|
||||
name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}"
|
||||
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=dir)
|
||||
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
|
||||
name = f"matmul_{dtype}"
|
||||
grid = f'M/{BM}, N/{BN}, 1'
|
||||
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir)
|
||||
|
||||
|
||||
def link_aot_kernels(dir):
|
||||
linker_path = os.path.join(triton.tools.__path__[0], "link.py")
|
||||
|
||||
# link all desired configs
|
||||
h_files = glob.glob(os.path.join(dir, "*.h"))
|
||||
@@ -183,17 +229,22 @@ def test_compile_link_matmul():
|
||||
dtype = "fp16"
|
||||
BM, BN, BK = 16, 16, 16
|
||||
|
||||
generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
# compile test case
|
||||
M, N, K = 16, 16, 16
|
||||
gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
|
||||
gen_kernel_library(tmp_dir, "libkernel.so")
|
||||
gen_test_bin(tmp_dir, M, N, K)
|
||||
|
||||
# initialize test data
|
||||
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
|
||||
|
||||
# run test case
|
||||
subprocess.run(["./test", a_path, b_path, c_path], check=True, cwd=tmp_dir)
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
|
||||
|
||||
# read data and compare against reference
|
||||
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
|
||||
@@ -209,23 +260,73 @@ def test_launcher_has_no_available_kernel():
|
||||
dtype = "fp16"
|
||||
BM, BN, BK = 16, 16, 16
|
||||
|
||||
generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=[":1"])
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[":1"])
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
# compile test case
|
||||
M, N, K = 16, 16, 16
|
||||
gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
|
||||
gen_kernel_library(tmp_dir, "libkernel.so")
|
||||
gen_test_bin(tmp_dir, M, N, K)
|
||||
|
||||
# initialize test data
|
||||
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
|
||||
|
||||
# run test case
|
||||
result = subprocess.run(["./test", a_path, b_path, c_path], cwd=tmp_dir, capture_output=True, text=True)
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True)
|
||||
|
||||
# It should fail since the launcher requires all the strides be 1 while they are not.
|
||||
assert result.returncode == -6
|
||||
assert "kernel launch failed" in result.stderr
|
||||
|
||||
|
||||
def test_compile_link_autotune_matmul():
|
||||
np.random.seed(3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
dtype = "fp16"
|
||||
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
|
||||
tile_sizes = [
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[64, 64, 32],
|
||||
]
|
||||
|
||||
for ts in tile_sizes:
|
||||
BM, BN, BK = ts[0], ts[1], ts[2]
|
||||
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
|
||||
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
gen_kernel_library(tmp_dir, "libkernel.so")
|
||||
|
||||
# compile test case
|
||||
M, N, K = 64, 64, 64
|
||||
# initialize test data
|
||||
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
|
||||
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
|
||||
|
||||
for algo_id in range(len(tile_sizes)):
|
||||
# generate and run test case
|
||||
test_name = f"test_{algo_id}"
|
||||
gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id)
|
||||
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env)
|
||||
|
||||
# read data and compare against reference
|
||||
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
|
||||
c_tri = c.reshape((M, N)).view(np.float32)
|
||||
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
def test_ttgir_to_ptx():
|
||||
src = """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} {
|
||||
|
||||
@@ -54,9 +54,12 @@ void load_{kernel_name}() {{
|
||||
/*
|
||||
{kernel_docstring}
|
||||
*/
|
||||
CUresult {kernel_name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {signature}) {{
|
||||
CUresult {kernel_name}(CUstream stream, {signature}) {{
|
||||
if ({kernel_name}_func == NULL)
|
||||
load_{kernel_name}();
|
||||
unsigned int gX = {gridX};
|
||||
unsigned int gY = {gridY};
|
||||
unsigned int gZ = {gridZ};
|
||||
void *args[{num_args}] = {{ {arg_pointers} }};
|
||||
// TODO: shared memory
|
||||
if(gX * gY * gZ > 0)
|
||||
|
||||
@@ -10,7 +10,5 @@
|
||||
|
||||
void unload_{kernel_name}(void);
|
||||
void load_{kernel_name}(void);
|
||||
// tt-linker: {kernel_name}:{full_signature}
|
||||
CUresult{_placeholder} {kernel_name}(CUstream stream, unsigned int gX,
|
||||
unsigned int gY, unsigned int gZ,
|
||||
{signature});
|
||||
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
|
||||
CUresult{_placeholder} {kernel_name}(CUstream stream, {signature});
|
||||
|
||||
@@ -43,10 +43,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
|
||||
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages meta-parameter for the kernel")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
|
||||
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
|
||||
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
|
||||
parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
out_name = args.out_name if args.out_name else args.kernel_name
|
||||
@@ -59,6 +60,8 @@ if __name__ == "__main__":
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
kernel = getattr(mod, args.kernel_name)
|
||||
grid = args.grid.split(",")
|
||||
assert len(grid) == 3
|
||||
|
||||
# validate and parse signature
|
||||
signature = list(map(lambda s: s.strip(" "), args.signature.split(",")))
|
||||
@@ -68,7 +71,8 @@ if __name__ == "__main__":
|
||||
m.update(" ".join(signature).encode())
|
||||
return m.hexdigest()[:8]
|
||||
|
||||
sig_hash = hash_signature(signature)
|
||||
meta_sig = f"warps{args.num_warps}xstages{args.num_stages}"
|
||||
sig_hash = hash_signature(signature + [meta_sig])
|
||||
|
||||
def constexpr(s):
|
||||
try:
|
||||
@@ -88,6 +92,9 @@ if __name__ == "__main__":
|
||||
constexprs = {i: constexpr(s) for i, s in enumerate(signature)}
|
||||
constexprs = {k: v for k, v in constexprs.items() if v is not None}
|
||||
signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constexprs}
|
||||
const_sig = 'x'.join([str(v) for v in constexprs.values()])
|
||||
doc_string = [f"{kernel.arg_names[i]}={constexprs[i]}" for i in constexprs.keys()]
|
||||
doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
|
||||
|
||||
# compile ast into cubin
|
||||
for h in hints.values():
|
||||
@@ -119,9 +126,13 @@ if __name__ == "__main__":
|
||||
"full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]),
|
||||
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]),
|
||||
"num_args": len(arg_names),
|
||||
"kernel_docstring": "",
|
||||
"kernel_docstring": doc_string,
|
||||
"shared": ccinfo.shared,
|
||||
"num_warps": args.num_warps,
|
||||
"algo_info": '_'.join([const_sig, meta_sig]),
|
||||
"gridX": grid[0],
|
||||
"gridY": grid[1],
|
||||
"gridZ": grid[2],
|
||||
"_placeholder": "",
|
||||
}
|
||||
for ext in ['h', 'c']:
|
||||
|
||||
@@ -15,10 +15,12 @@ class LinkerError(Exception):
|
||||
|
||||
@dataclass
|
||||
class KernelLinkerMeta:
|
||||
orig_kernel_name: str
|
||||
arg_names: Sequence[str]
|
||||
arg_ctypes: Sequence[str]
|
||||
sizes: Sequence[Union[int, None]]
|
||||
sig_hash: str
|
||||
triton_suffix: str
|
||||
suffix: str
|
||||
num_specs: int
|
||||
""" number of specialized arguments """
|
||||
@@ -29,8 +31,8 @@ class HeaderParser:
|
||||
import re
|
||||
|
||||
# [kernel_name, c signature]
|
||||
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+)")
|
||||
# [name, suffix]
|
||||
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
|
||||
# [name, hash, suffix]
|
||||
self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
|
||||
# [(type, name)]
|
||||
self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?")
|
||||
@@ -45,17 +47,19 @@ class HeaderParser:
|
||||
if ln.startswith("//"):
|
||||
m = self.linker_directives.match(ln)
|
||||
if _exists(m):
|
||||
ker_name, c_sig = m.group(1), m.group(2)
|
||||
ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3)
|
||||
name, sig_hash, suffix = self._match_name(ker_name)
|
||||
c_types, arg_names = self._match_c_sig(c_sig)
|
||||
num_specs, sizes = self._match_suffix(suffix, c_sig)
|
||||
self._add_kernel(
|
||||
name,
|
||||
"_".join([name, algo_info]),
|
||||
KernelLinkerMeta(
|
||||
orig_kernel_name=name,
|
||||
arg_names=arg_names,
|
||||
arg_ctypes=c_types,
|
||||
sizes=sizes,
|
||||
sig_hash=sig_hash,
|
||||
triton_suffix=suffix,
|
||||
suffix=suffix,
|
||||
num_specs=num_specs,
|
||||
),
|
||||
@@ -126,28 +130,48 @@ def gen_signature(m):
|
||||
return sig
|
||||
|
||||
|
||||
def make_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
|
||||
# generate declarations of kernels with meta-parameter and constant values
|
||||
def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
|
||||
return f"""
|
||||
CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])});
|
||||
CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])});
|
||||
void load_{name}();
|
||||
void unload_{name}();
|
||||
"""
|
||||
|
||||
|
||||
def make_kernel_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
|
||||
# generate declarations of kernels with meta-parameter and constant values
|
||||
def make_global_decl(meta: KernelLinkerMeta) -> str:
|
||||
return f"""
|
||||
CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)});
|
||||
CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id);
|
||||
void load_{meta.orig_kernel_name}();
|
||||
void unload_{meta.orig_kernel_name}();
|
||||
"""
|
||||
|
||||
|
||||
# generate dispatcher function for kernels with different meta-parameter and constant values
|
||||
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
|
||||
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
|
||||
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
|
||||
# generate dispatcher function for kernels with different integer value hints
|
||||
def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
|
||||
src = f"// launcher for: {name}\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += f"CUresult {name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(meta)});\n"
|
||||
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
|
||||
src += "\n"
|
||||
|
||||
src += f"CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
|
||||
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
|
||||
src += f" if ({conds})\n"
|
||||
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
|
||||
src += f" return {name}_{meta.sig_hash}_{meta.suffix}(stream, gX, gY, gZ, {', '.join(arg_names)});\n"
|
||||
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
|
||||
src += "\n"
|
||||
src += " return CUDA_ERROR_INVALID_VALUE;\n"
|
||||
src += "}\n"
|
||||
@@ -155,15 +179,58 @@ def make_kernel_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
|
||||
for mode in ["load", "unload"]:
|
||||
src += f"\n// {mode} for: {name}\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += f"void {mode}_{name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += f"void {mode}_{name}() {{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += f" {mode}_{name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
|
||||
# generate dispatcher function for kernels with different meta-parameter and constant values
|
||||
def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str:
|
||||
src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
|
||||
src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n"
|
||||
src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n"
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
|
||||
# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values
|
||||
def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str:
|
||||
# the table of hint dispatchers
|
||||
src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n"
|
||||
src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n"
|
||||
for name in names:
|
||||
src += f" {name},\n"
|
||||
src += "};\n"
|
||||
return src
|
||||
|
||||
|
||||
# generate definition for load/unload functions for kernels with different meta-parameter and constant values
|
||||
def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str:
|
||||
src = ""
|
||||
for mode in ["load", "unload"]:
|
||||
src += f"void {mode}_{meta.orig_kernel_name}(void){{\n"
|
||||
for name in names:
|
||||
src += f" {mode}_{name}();\n"
|
||||
src += "}\n\n"
|
||||
return src
|
||||
|
||||
|
||||
def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str:
|
||||
src = f"int {meta.orig_kernel_name}_get_num_algos(void);"
|
||||
return src
|
||||
|
||||
|
||||
def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
|
||||
src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n"
|
||||
src += f" return (int)sizeof({meta.orig_kernel_name}_kernels);\n"
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
|
||||
desc = """
|
||||
Triton ahead-of-time linker:
|
||||
|
||||
@@ -198,16 +265,43 @@ if __name__ == "__main__":
|
||||
parser.extract_linker_meta(h_str)
|
||||
|
||||
# generate headers
|
||||
decls = [make_decls(name, meta) for name, meta in parser.kernels.items()]
|
||||
algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()]
|
||||
meta_lists = [meta for name, meta in parser.kernels.items()]
|
||||
meta = meta_lists[0][0]
|
||||
get_num_algos_decl = make_get_num_algos_decl(meta)
|
||||
global_decl = make_global_decl(meta)
|
||||
with args.out.with_suffix(".h").open("w") as fp:
|
||||
fp.write("#include <cuda.h>\n" + "\n".join(decls))
|
||||
out = "#include <cuda.h>\n"
|
||||
out += "\n".join(algo_decls)
|
||||
out += "\n"
|
||||
out += get_num_algos_decl
|
||||
out += "\n"
|
||||
out += global_decl
|
||||
fp.write(out)
|
||||
|
||||
# generate source
|
||||
defs = [make_kernel_dispatcher(name, meta) for name, meta in parser.kernels.items()]
|
||||
defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()]
|
||||
names = [name for name in parser.kernels.keys()]
|
||||
func_pointers_def = make_func_pointers(names, meta)
|
||||
meta_const_def = make_kernel_meta_const_dispatcher(meta)
|
||||
load_unload_def = make_kernel_load_def(names, meta)
|
||||
get_num_algos_def = make_get_num_algos_def(meta)
|
||||
default_algo_kernel = make_default_algo_kernel(meta)
|
||||
with args.out.with_suffix(".c").open("w") as fp:
|
||||
out = ""
|
||||
out += "#include <cuda.h>\n"
|
||||
out += "#include <stdint.h>\n"
|
||||
out += "#include <assert.h>\n"
|
||||
out += "\n"
|
||||
out += "\n".join(defs)
|
||||
out += "\n"
|
||||
out += func_pointers_def
|
||||
out += "\n"
|
||||
out += get_num_algos_def
|
||||
out += "\n"
|
||||
out += meta_const_def
|
||||
out += "\n"
|
||||
out += load_unload_def
|
||||
out += "\n"
|
||||
out += default_algo_kernel
|
||||
fp.write(out)
|
||||
|
||||
Reference in New Issue
Block a user