mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TOOLS] Add support for autotuning AOT kernel (#2123)
This PR makes the following change to AOT kernel
- Allow the client to generate AOT kernels with different sets of
constexprs and meta-parameters. Each combination of a constexpr set and
meta-parameters is referred to as an "algo". Within an algo, the client can
still give different hints about integer arguments.
- Add an API `int ${kernel_name}_get_num_algos()` that returns the total
number of algos.
- Add an algo_id parameter to the generated kernel so the client can select
the algo.
- Remove gX, gY and gZ from the kernel parameter list. This is because
the launch grid is usually different with different algos, and the
client should not need to care about how to compute the launch grid for
each algo. Instead, we ask the client to pass the expression of
computing gX, gY and gZ for compile.py (when AOT kernels are generated).
The expression can only use kernel parameter or const values.
- We also change the testing flow. Now we first build the kernels into a
shared library libkernel.so, then the client test.c code is built and
linked with libkernel.so. This is closer to a typical AOT kernel usage
flow.
This commit is contained in:
@@ -23,26 +23,43 @@ import triton.language as tl
|
||||
import kernel_utils
|
||||
|
||||
@triton.jit
|
||||
def kernel(C, A, B,
|
||||
def kernel(C, A, B, M, N, K,
|
||||
stride_cm, stride_cn,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr):
|
||||
ms = tl.arange(0, BLOCK_M)
|
||||
ns = tl.arange(0, BLOCK_N)
|
||||
ks = tl.arange(0, BLOCK_K)
|
||||
a = tl.load(A + ms[:, None] * stride_am + ks[None, :] * stride_ak)
|
||||
b = tl.load(B + ks[:, None] * stride_bk + ns[None, :] * stride_bn)
|
||||
c = tl.dot(a, b)
|
||||
c = kernel_utils.mul(c, c)
|
||||
tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c)
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
|
||||
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_K)
|
||||
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the K dimension.
|
||||
# If it is out of bounds, set it to 0.
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_K * stride_ak
|
||||
b_ptrs += BLOCK_K * stride_bk
|
||||
|
||||
c = kernel_utils.mul(accumulator, accumulator)
|
||||
# Write back the block of the output matrix C with masks.
|
||||
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
tl.store(c_ptrs, c)
|
||||
"""
|
||||
|
||||
|
||||
def gen_test_bin(dir, M, N, K, BM, BN, BK):
|
||||
test_src = '''
|
||||
test_utils_src = '''
|
||||
#include <cuda.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
@@ -78,10 +95,23 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
|
||||
fclose(file);
|
||||
}'''
|
||||
|
||||
test_src += f'''
|
||||
|
||||
def gen_kernel_library(dir, libname):
    """Build the AOT-generated C kernel sources in *dir* into a shared library.

    Compiles every ``*.c`` file in ``dir`` into position-independent object
    files, then links all resulting ``*.o`` files into ``libname`` (e.g.
    ``libkernel.so``) inside the same directory, so a client test binary can
    be linked against it like a typical AOT-kernel consumer would.

    Raises
    ------
    subprocess.CalledProcessError
        If either gcc invocation fails (``check=True``).
    """
    # Compile each generated kernel source to a PIC object file.
    c_files = glob.glob(os.path.join(dir, "*.c"))
    subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
                    "-c", "-fPIC"],
                   check=True, cwd=dir)
    # Link all object files into a single shared library.
    o_files = glob.glob(os.path.join(dir, "*.o"))
    subprocess.run(["gcc"] + o_files + ["-shared",
                    "-o", libname,
                    "-L", libcuda_dirs()[0]],
                   check=True, cwd=dir)
|
||||
|
||||
|
||||
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
|
||||
test_src = f'''
|
||||
int main(int argc, char **argv) {{
|
||||
int M = {M}, N = {N}, K = {K};
|
||||
int BM = {M}, BN = {N}, BK = {K};
|
||||
|
||||
// initialize CUDA handles
|
||||
CUdevice dev;
|
||||
@@ -96,7 +126,7 @@ int main(int argc, char **argv) {{
|
||||
cuMemAlloc(&B, K * N * 2);
|
||||
cuMemAlloc(&C, M * N * 4);
|
||||
cuStreamCreate(&stream, 0);
|
||||
load_matmul_fp16xfp16_16x16x16();
|
||||
load_matmul_fp16();
|
||||
|
||||
// initialize input data
|
||||
int16_t hA[M*K];
|
||||
@@ -110,7 +140,13 @@ int main(int argc, char **argv) {{
|
||||
|
||||
// launch kernel
|
||||
cuStreamSynchronize(stream);
|
||||
CUresult ret = matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1);
|
||||
CUresult ret;
|
||||
int algo_id = {algo_id};
|
||||
if (algo_id == 0) {{
|
||||
ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1);
|
||||
}} else {{
|
||||
ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id});
|
||||
}}
|
||||
if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
|
||||
assert(ret == 0);
|
||||
|
||||
@@ -123,41 +159,51 @@ int main(int argc, char **argv) {{
|
||||
write_buffer_to_csv(argv[3], hC, M*N);
|
||||
|
||||
// free cuda handles
|
||||
unload_matmul_fp16xfp16_16x16x16();
|
||||
unload_matmul_fp16();
|
||||
cuMemFree(A);
|
||||
cuMemFree(B);
|
||||
cuMemFree(C);
|
||||
cuCtxDestroy(ctx);
|
||||
}}
|
||||
'''
|
||||
|
||||
src = test_utils_src + test_src
|
||||
with open(os.path.join(dir, "test.c"), "w") as file:
|
||||
file.write(test_src)
|
||||
c_files = glob.glob(os.path.join(dir, "*.c"))
|
||||
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
|
||||
"-L", libcuda_dirs()[0],
|
||||
"-l", "cuda",
|
||||
"-o", "test"], check=True, cwd=dir)
|
||||
file.write(src)
|
||||
subprocess.run(["gcc"] + ["test.c",
|
||||
"-I", cuda_include_dir(),
|
||||
"-L", libcuda_dirs()[0],
|
||||
"-l", "cuda",
|
||||
"-L", dir,
|
||||
"-l", "kernel",
|
||||
"-o", exe], check=True, cwd=dir)
|
||||
|
||||
|
||||
def generate_matmul_launcher(dir, dtype, BM, BN, BK, ha_hb_hints):
|
||||
def write_triton_kernels(dir, src, util_src):
    """Write the Triton kernel and its helper module into *dir*.

    Parameters
    ----------
    dir : str
        Target directory for the generated Python sources.
    src : str
        Source of the Triton kernel; written to ``kernel.py``.
    util_src : str
        Source of the helper module the kernel imports; written to
        ``kernel_utils.py``.

    Returns
    -------
    str
        Path of the written ``kernel.py`` (to be handed to ``compile.py``).
    """
    kernel_path = os.path.join(dir, "kernel.py")
    with open(kernel_path, "w") as f:
        f.write(src)

    # The kernel does `import kernel_utils`, so the helper must live next
    # to kernel.py.
    kernel_utils_path = os.path.join(dir, "kernel_utils.py")
    with open(kernel_utils_path, "w") as f:
        f.write(util_src)

    return kernel_path
|
||||
|
||||
|
||||
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
    """AOT-compile the matmul kernel once per (hA, hB) hint combination.

    For each pair of divisibility hints drawn from *ha_hb_hints* (applied to
    the A and B row strides), invokes triton's ``compile.py`` on
    *kernel_path* with the full signature: the C/A/B pointers, the runtime
    M/N/K sizes, the stride arguments, and the BM/BN/BK tile-size
    constexprs. The launch-grid expression is passed via ``-g`` so the
    generated launcher computes gX/gY/gZ from kernel parameters itself; the
    expression may only reference kernel parameters or constants.
    """
    compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")

    # Compile all desired configs.
    for ha in ha_hb_hints:
        for hb in ha_hb_hints:
            sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
            name = f"matmul_{dtype}"
            grid = f'M/{BM}, N/{BN}, 1'
            subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir)
|
||||
|
||||
|
||||
def link_aot_kernels(dir):
|
||||
linker_path = os.path.join(triton.tools.__path__[0], "link.py")
|
||||
|
||||
# link all desired configs
|
||||
h_files = glob.glob(os.path.join(dir, "*.h"))
|
||||
@@ -183,17 +229,22 @@ def test_compile_link_matmul():
|
||||
dtype = "fp16"
|
||||
BM, BN, BK = 16, 16, 16
|
||||
|
||||
generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
# compile test case
|
||||
M, N, K = 16, 16, 16
|
||||
gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
|
||||
gen_kernel_library(tmp_dir, "libkernel.so")
|
||||
gen_test_bin(tmp_dir, M, N, K)
|
||||
|
||||
# initialize test data
|
||||
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
|
||||
|
||||
# run test case
|
||||
subprocess.run(["./test", a_path, b_path, c_path], check=True, cwd=tmp_dir)
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
|
||||
|
||||
# read data and compare against reference
|
||||
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
|
||||
@@ -209,23 +260,73 @@ def test_launcher_has_no_available_kernel():
|
||||
dtype = "fp16"
|
||||
BM, BN, BK = 16, 16, 16
|
||||
|
||||
generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=[":1"])
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[":1"])
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
# compile test case
|
||||
M, N, K = 16, 16, 16
|
||||
gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
|
||||
gen_kernel_library(tmp_dir, "libkernel.so")
|
||||
gen_test_bin(tmp_dir, M, N, K)
|
||||
|
||||
# initialize test data
|
||||
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
|
||||
|
||||
# run test case
|
||||
result = subprocess.run(["./test", a_path, b_path, c_path], cwd=tmp_dir, capture_output=True, text=True)
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True)
|
||||
|
||||
# It should fail since the launcher requires all the strides be 1 while they are not.
|
||||
assert result.returncode == -6
|
||||
assert "kernel launch failed" in result.stderr
|
||||
|
||||
|
||||
def test_compile_link_autotune_matmul():
    """End-to-end autotune test: compile several tile-size algos, link them
    into one shared library, then run a client binary once per algo_id and
    check each algo produces the reference result."""
    np.random.seed(3)

    with tempfile.TemporaryDirectory() as tmp_dir:

        dtype = "fp16"

        kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)

        # Each tile size becomes one "algo" in the linked launcher.
        tile_sizes = [
            [16, 16, 16],
            [32, 32, 16],
            [32, 32, 32],
            [64, 64, 32],
        ]

        for ts in tile_sizes:
            BM, BN, BK = ts[0], ts[1], ts[2]
            compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])

        link_aot_kernels(tmp_dir)

        gen_kernel_library(tmp_dir, "libkernel.so")

        # compile test case
        M, N, K = 64, 64, 64
        # initialize test data
        a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
        # The kernel computes (A @ B) squared elementwise (kernel_utils.mul).
        c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))

        for algo_id in range(len(tile_sizes)):
            # generate and run one test binary per algo
            test_name = f"test_{algo_id}"
            gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id)

            # The binary links against libkernel.so in tmp_dir at runtime.
            env = os.environ.copy()
            env["LD_LIBRARY_PATH"] = tmp_dir
            subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env)

            # read data and compare against reference
            c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
            c_tri = c.reshape((M, N)).view(np.float32)
            np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
def test_ttgir_to_ptx():
|
||||
src = """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} {
|
||||
|
||||
Reference in New Issue
Block a user