[TOOLS] Add support for autotuning AOT kernel (#2123)

This PR makes the following change to AOT kernel

- Allow the client to generate AOT kernels with different sets of
constexprs and meta-parameters. Each combination of constexpr set and
meta-parameters is referred to as an "algo". Within an algo, the client can
still give different hints about integer arguments.
- Add an API int ${kernel_name}_get_num_algos() that returns the total
number of algos.
- Add an algo_id parameter to allow the client of the generated kernel to
select the algo.
- Remove gX, gY and gZ from the kernel parameter list. This is because
the launch grid is usually different for different algos, and the
client should not need to care about how to compute the launch grid for
each algo. Instead, we ask the client to pass the expressions for
computing gX, gY and gZ to compile.py (when AOT kernels are generated).
The expressions can only use kernel parameters or constant values.
- We also change the testing flow. Now we first build the kernels into a
shared library libkernel.so, then the client test.c code is built and
linked with libkernel.so. This is closer to a typical AOT kernel usage
flow.
This commit is contained in:
Bin Fan
2023-08-23 09:38:29 -07:00
committed by GitHub
parent 5282ed890d
commit dad83f9dcb
5 changed files with 267 additions and 60 deletions

View File

@@ -23,26 +23,43 @@ import triton.language as tl
import kernel_utils
@triton.jit
def kernel(C, A, B,
def kernel(C, A, B, M, N, K,
stride_cm, stride_cn,
stride_am, stride_ak,
stride_bk, stride_bn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
ms = tl.arange(0, BLOCK_M)
ns = tl.arange(0, BLOCK_N)
ks = tl.arange(0, BLOCK_K)
a = tl.load(A + ms[:, None] * stride_am + ks[None, :] * stride_ak)
b = tl.load(B + ks[:, None] * stride_bk + ns[None, :] * stride_bn)
c = tl.dot(a, b)
c = kernel_utils.mul(c, c)
tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = kernel_utils.mul(accumulator, accumulator)
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_ptrs, c)
"""
def gen_test_bin(dir, M, N, K, BM, BN, BK):
test_src = '''
test_utils_src = '''
#include <cuda.h>
#include <stdio.h>
#include <stdint.h>
@@ -78,10 +95,23 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
fclose(file);
}'''
test_src += f'''
# Build all generated kernel .c sources in `dir` into a single shared
# library named `libname` (e.g. "libkernel.so"), so client test binaries
# can link against the AOT kernels instead of compiling them in directly.
# NOTE(review): this is a diff-view fragment; the original indentation was
# stripped by the rendering, so comments are added flush-left to match.
def gen_kernel_library(dir, libname):
# compile every generated C source into position-independent object files
c_files = glob.glob(os.path.join(dir, "*.c"))
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
"-c", "-fPIC"],
check=True, cwd=dir)
# link all object files into one shared library, searching libcuda's dir
o_files = glob.glob(os.path.join(dir, "*.o"))
subprocess.run(["gcc"] + o_files + ["-shared",
"-o", libname,
"-L", libcuda_dirs()[0]],
check=True, cwd=dir)
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
test_src = f'''
int main(int argc, char **argv) {{
int M = {M}, N = {N}, K = {K};
int BM = {M}, BN = {N}, BK = {K};
// initialize CUDA handles
CUdevice dev;
@@ -96,7 +126,7 @@ int main(int argc, char **argv) {{
cuMemAlloc(&B, K * N * 2);
cuMemAlloc(&C, M * N * 4);
cuStreamCreate(&stream, 0);
load_matmul_fp16xfp16_16x16x16();
load_matmul_fp16();
// initialize input data
int16_t hA[M*K];
@@ -110,7 +140,13 @@ int main(int argc, char **argv) {{
// launch kernel
cuStreamSynchronize(stream);
CUresult ret = matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1);
CUresult ret;
int algo_id = {algo_id};
if (algo_id == 0) {{
ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1);
}} else {{
ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id});
}}
if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
assert(ret == 0);
@@ -123,41 +159,51 @@ int main(int argc, char **argv) {{
write_buffer_to_csv(argv[3], hC, M*N);
// free cuda handles
unload_matmul_fp16xfp16_16x16x16();
unload_matmul_fp16();
cuMemFree(A);
cuMemFree(B);
cuMemFree(C);
cuCtxDestroy(ctx);
}}
'''
src = test_utils_src + test_src
with open(os.path.join(dir, "test.c"), "w") as file:
file.write(test_src)
c_files = glob.glob(os.path.join(dir, "*.c"))
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
"-L", libcuda_dirs()[0],
"-l", "cuda",
"-o", "test"], check=True, cwd=dir)
file.write(src)
subprocess.run(["gcc"] + ["test.c",
"-I", cuda_include_dir(),
"-L", libcuda_dirs()[0],
"-l", "cuda",
"-L", dir,
"-l", "kernel",
"-o", exe], check=True, cwd=dir)
def generate_matmul_launcher(dir, dtype, BM, BN, BK, ha_hb_hints):
def write_triton_kernels(dir, src, util_src):
kernel_path = os.path.join(dir, "kernel.py")
with open(kernel_path, "w") as file:
file.write(kernel_src)
file.write(src)
kernel_utils_path = os.path.join(dir, "kernel_utils.py")
with open(kernel_utils_path, "w") as file:
file.write(kernel_utils_src)
file.write(util_src)
return kernel_path
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
linker_path = os.path.join(triton.tools.__path__[0], "link.py")
# compile all desired configs
for ha in ha_hb_hints:
for hb in ha_hb_hints:
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}"
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=dir)
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
name = f"matmul_{dtype}"
grid = f'M/{BM}, N/{BN}, 1'
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir)
def link_aot_kernels(dir):
linker_path = os.path.join(triton.tools.__path__[0], "link.py")
# link all desired configs
h_files = glob.glob(os.path.join(dir, "*.h"))
@@ -183,17 +229,22 @@ def test_compile_link_matmul():
dtype = "fp16"
BM, BN, BK = 16, 16, 16
generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
link_aot_kernels(tmp_dir)
# compile test case
M, N, K = 16, 16, 16
gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
gen_kernel_library(tmp_dir, "libkernel.so")
gen_test_bin(tmp_dir, M, N, K)
# initialize test data
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
# run test case
subprocess.run(["./test", a_path, b_path, c_path], check=True, cwd=tmp_dir)
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
@@ -209,23 +260,73 @@ def test_launcher_has_no_available_kernel():
dtype = "fp16"
BM, BN, BK = 16, 16, 16
generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=[":1"])
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[":1"])
link_aot_kernels(tmp_dir)
# compile test case
M, N, K = 16, 16, 16
gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
gen_kernel_library(tmp_dir, "libkernel.so")
gen_test_bin(tmp_dir, M, N, K)
# initialize test data
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
# run test case
result = subprocess.run(["./test", a_path, b_path, c_path], cwd=tmp_dir, capture_output=True, text=True)
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True)
# It should fail since the launcher requires all the strides be 1 while they are not.
assert result.returncode == -6
assert "kernel launch failed" in result.stderr
# End-to-end autotuning test: compile the same Triton kernel once per tile
# size (each compiled variant becomes one "algo" in the AOT library), link
# all variants into libkernel.so, then build and run one client binary per
# algo_id and compare every result against a NumPy reference.
def test_compile_link_autotune_matmul():
np.random.seed(3)
with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
# each [BM, BN, BK] tile size yields one distinct algo in the library
tile_sizes = [
[16, 16, 16],
[32, 32, 16],
[32, 32, 32],
[64, 64, 32],
]
for ts in tile_sizes:
BM, BN, BK = ts[0], ts[1], ts[2]
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
link_aot_kernels(tmp_dir)
gen_kernel_library(tmp_dir, "libkernel.so")
# compile test case
M, N, K = 64, 64, 64
# initialize test data
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
# the kernel squares the matmul result via kernel_utils.mul(acc, acc),
# hence the c_ref * c_ref comparison below
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
for algo_id in range(len(tile_sizes)):
# generate and run test case
test_name = f"test_{algo_id}"
gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id)
# libkernel.so lives in tmp_dir; point the dynamic loader at it
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env)
# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
# the C side wrote raw fp32 bits as int32 CSV; reinterpret, don't convert
c_tri = c.reshape((M, N)).view(np.float32)
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=1e-4)
def test_ttgir_to_ptx():
src = """
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} {