mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
- Support memory space for pointers (e.g., `!tt.ptr<f32, 1>`). - Support parsing function attributes, though they are not used yet.
346 lines
11 KiB
Python
346 lines
11 KiB
Python
import glob
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
|
|
import numpy as np
|
|
|
|
import triton
|
|
from triton.common import cuda_include_dir, libcuda_dirs
|
|
|
|
kernel_utils_src = """
|
|
import triton
|
|
|
|
@triton.jit
|
|
def mul(x, y):
|
|
return x * y
|
|
"""
|
|
|
|
kernel_src = """
|
|
import triton
|
|
import triton.language as tl
|
|
import kernel_utils
|
|
|
|
@triton.jit
|
|
def kernel(C, A, B, M, N, K,
|
|
stride_cm, stride_cn,
|
|
stride_am, stride_ak,
|
|
stride_bk, stride_bn,
|
|
BLOCK_M: tl.constexpr,
|
|
BLOCK_N: tl.constexpr,
|
|
BLOCK_K: tl.constexpr):
|
|
pid_m = tl.program_id(0)
|
|
pid_n = tl.program_id(1)
|
|
|
|
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
|
|
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
|
|
offs_k = tl.arange(0, BLOCK_K)
|
|
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
|
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
|
|
|
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
|
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
|
# Load the next block of A and B, generate a mask by checking the K dimension.
|
|
# If it is out of bounds, set it to 0.
|
|
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
|
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
|
|
# We accumulate along the K dimension.
|
|
accumulator += tl.dot(a, b)
|
|
# Advance the ptrs to the next K block.
|
|
a_ptrs += BLOCK_K * stride_ak
|
|
b_ptrs += BLOCK_K * stride_bk
|
|
|
|
c = kernel_utils.mul(accumulator, accumulator)
|
|
# Write back the block of the output matrix C with masks.
|
|
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
|
tl.store(c_ptrs, c)
|
|
"""
|
|
|
|
# C prelude that ``gen_test_bin`` prepends to the generated ``main``.
# It pulls in the CUDA driver API and the linked ``kernel.h`` header and
# provides two CSV helpers:
#   * ``write_buffer_to_csv`` dumps ``size`` int32 values, comma separated;
#   * ``read_csv_to_buffer`` parses up to ``size`` comma separated int16 values.
# The reader checks the bound *before* calling ``fscanf`` so that it never
# writes to ``buffer[size]``, and it stops on the first failed conversion
# instead of only on EOF.
test_utils_src = '''
#include <cuda.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <assert.h>
#include "kernel.h"

static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) {
  FILE *file = fopen(filename, "w");
  if (file == NULL) {
    printf("Could not open file %s\\n", filename);
    return;
  }
  for (int i = 0; i < size; i++) {
    fprintf(file, "%d", buffer[i]);
    if (i < size - 1) {
      fprintf(file, ",");
    }
  }
  fclose(file);
}

static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
  FILE *file = fopen(filename, "r");
  if (file == NULL) {
    printf("Could not open file %s\\n", filename);
    return;
  }
  int index = 0;
  // Test the bound first so buffer[size] is never written, and stop as soon
  // as a value fails to parse (fscanf returns the number of conversions).
  while (index < size && fscanf(file, "%hd,", &buffer[index]) == 1) {
    index++;
  }
  fclose(file);
}'''
|
|
|
|
|
|
def gen_kernel_library(dir, libname):
    """Build a shared library from every C file found in ``dir``.

    Two ``gcc`` invocations are run with ``dir`` as working directory:
    first all ``*.c`` files are compiled to position-independent objects
    (with the CUDA include directory on the include path), then all ``*.o``
    files are linked into the shared object ``libname``.

    Args:
        dir: Directory holding the generated C sources; also receives the
            object files and the resulting library.
        libname: Output file name of the shared library.

    Raises:
        subprocess.CalledProcessError: If either ``gcc`` step fails.
    """
    sources = glob.glob(os.path.join(dir, "*.c"))
    compile_cmd = ["gcc", *sources, "-I", cuda_include_dir(), "-c", "-fPIC"]
    subprocess.run(compile_cmd, check=True, cwd=dir)

    # The objects are globbed only after compilation so the fresh ones are seen.
    objects = glob.glob(os.path.join(dir, "*.o"))
    link_cmd = ["gcc", *objects, "-shared", "-o", libname, "-L", libcuda_dirs()[0]]
    subprocess.run(link_cmd, check=True, cwd=dir)
|
|
|
|
|
|
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
    """Generate ``test.c`` in ``dir`` and compile it into the executable ``exe``.

    The generated program allocates device buffers for an ``M x K`` and a
    ``K x N`` 16-bit input and an ``M x N`` 32-bit output, reads the inputs
    from the CSV files given as ``argv[1]`` / ``argv[2]``, launches the linked
    matmul kernel and writes the output to the CSV file given as ``argv[3]``.
    With ``algo_id == 0`` it calls ``matmul_fp16_default``; otherwise it calls
    ``matmul_fp16`` with ``algo_id`` as the last argument. A non-zero launch
    result prints ``kernel launch failed`` to stderr and trips an ``assert``.

    Args:
        dir: Directory to write ``test.c`` into; also used as ``gcc`` cwd and
            as library search path for ``-l kernel``.
        M, N, K: Problem sizes baked into the generated C source.
        exe: Name of the produced executable.
        algo_id: Kernel selector baked into the generated C source.

    Raises:
        subprocess.CalledProcessError: If ``gcc`` fails.
    """
    main_src = f'''
int main(int argc, char **argv) {{
  int M = {M}, N = {N}, K = {K};

  // initialize CUDA handles
  CUdevice dev;
  CUcontext ctx;
  CUstream stream;
  CUdeviceptr A, B, C;
  CUresult err = 0;
  cuInit(0);
  cuDeviceGet(&dev, 0);
  cuCtxCreate(&ctx, 0, dev);
  cuMemAlloc(&A, M * K * 2);
  cuMemAlloc(&B, K * N * 2);
  cuMemAlloc(&C, M * N * 4);
  cuStreamCreate(&stream, 0);
  load_matmul_fp16();

  // initialize input data
  int16_t hA[M*K];
  int16_t hB[K*N];
  memset(hA, 0, M*K*2);
  memset(hB, 0, K*N*2);
  read_csv_to_buffer(argv[1], hA, M*K);
  read_csv_to_buffer(argv[2], hB, K*N);
  cuMemcpyHtoD(A, hA, M*K*2);
  cuMemcpyHtoD(B, hB, K*N*2);

  // launch kernel
  cuStreamSynchronize(stream);
  CUresult ret;
  int algo_id = {algo_id};
  if (algo_id == 0) {{
    ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1);
  }} else {{
    ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id});
  }}
  if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
  assert(ret == 0);

  cuStreamSynchronize(stream);

  // read data
  int32_t hC[M*N];
  memset(hC, 0, M*N*4);
  cuMemcpyDtoH(hC, C, M*N*4);
  write_buffer_to_csv(argv[3], hC, M*N);

  // free cuda handles
  unload_matmul_fp16();
  cuMemFree(A);
  cuMemFree(B);
  cuMemFree(C);
  cuCtxDestroy(ctx);
}}
'''
    # The shared C prelude (includes + CSV helpers) goes in front of main().
    program = "".join((test_utils_src, main_src))
    with open(os.path.join(dir, "test.c"), "w") as out:
        out.write(program)

    gcc_cmd = [
        "gcc", "test.c",
        "-I", cuda_include_dir(),
        "-L", libcuda_dirs()[0],
        "-l", "cuda",
        "-L", dir,
        "-l", "kernel",
        "-o", exe,
    ]
    subprocess.run(gcc_cmd, check=True, cwd=dir)
|
|
|
|
|
|
def write_triton_kernels(dir, src, util_src):
    """Write the kernel module and its helper module into ``dir``.

    Args:
        dir: Target directory.
        src: Text written to ``kernel.py``.
        util_src: Text written to ``kernel_utils.py``.

    Returns:
        The path of the written ``kernel.py``.
    """
    kernel_path = os.path.join(dir, "kernel.py")
    outputs = (
        (kernel_path, src),
        (os.path.join(dir, "kernel_utils.py"), util_src),
    )
    for target, text in outputs:
        with open(target, "w") as out:
            out.write(text)
    return kernel_path
|
|
|
|
|
|
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
    """Run Triton's ``compile.py`` once per ``(ha, hb)`` hint combination.

    Every element of ``ha_hb_hints`` is paired with every element (itself
    included); each pair yields one invocation of the compiler script with
    ``dir`` as working directory. In the signature string ``ha`` is appended
    to the 7th argument type and ``hb`` to the 9th.

    Args:
        dir: Working directory for the compiler; outputs are named
            ``matmul_<dtype>`` relative to it.
        kernel_path: Path of the Python file containing the ``kernel``
            function to compile.
        dtype: Element type name used for the two input pointers and for the
            output name.
        BM, BN, BK: Block sizes; passed as the trailing signature entries and
            used for the launch grid ``M/BM, N/BN, 1``.
        ha_hb_hints: Hint suffixes such as ``""``, ``":1"`` or ``":16"``.

    Raises:
        subprocess.CalledProcessError: If any compiler invocation fails.
    """
    compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")

    # Output name and grid do not depend on the hints, so build them once.
    out_name = f"matmul_{dtype}"
    launch_grid = f'M/{BM}, N/{BN}, 1'

    hint_pairs = [(ha, hb) for ha in ha_hb_hints for hb in ha_hb_hints]
    for ha, hb in hint_pairs:
        signature = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
        cmd = [
            sys.executable, compiler_path,
            "-n", "kernel",
            "--signature", signature,
            "--out-name", out_name,
            "-o", out_name,
            "-w", "1",
            "-g", launch_grid,
            kernel_path,
        ]
        subprocess.run(cmd, check=True, cwd=dir)
|
|
|
|
|
|
def link_aot_kernels(dir):
    """Run Triton's ``link.py`` over every header in ``dir``.

    All ``*.h`` files found in ``dir`` are handed to the linker script, which
    is asked to write its output under the name ``kernel`` (relative to
    ``dir``, the working directory of the call).

    Args:
        dir: Directory containing the headers emitted by the AOT compiler.

    Raises:
        subprocess.CalledProcessError: If the linker script fails.
    """
    linker_path = os.path.join(triton.tools.__path__[0], "link.py")
    headers = glob.glob(os.path.join(dir, "*.h"))
    link_cmd = [sys.executable, linker_path, *headers, "-o", "kernel"]
    subprocess.run(link_cmd, check=True, cwd=dir)
|
|
|
|
|
|
def generate_matmul_test_data(dir, M, N, K):
    """Create random fp16 matmul operands and dump them as CSV files.

    ``a`` has shape ``(M, K)`` and ``b`` has shape ``(K, N)``. Each matrix is
    written flattened, with its float16 bit patterns reinterpreted as int16
    and separated by commas, to ``a.csv`` / ``b.csv`` inside ``dir``. The
    path for ``c.csv`` is only computed; the file itself is not created here.

    Args:
        dir: Directory that receives ``a.csv`` and ``b.csv``.
        M, N, K: Matrix dimensions.

    Returns:
        Tuple ``(a, b, a_path, b_path, c_path)``.
    """
    a = np.random.randn(M * K).astype(np.float16).reshape((M, K))
    # b needs K * N samples; drawing M * K (as before) only reshapes to
    # (K, N) when M == N and raises otherwise.
    b = np.random.randn(K * N).astype(np.float16).reshape((K, N))
    a_path = os.path.join(dir, "a.csv")
    b_path = os.path.join(dir, "b.csv")
    c_path = os.path.join(dir, "c.csv")
    for x, path in ((a, a_path), (b, b_path)):
        # Store raw bit patterns as integers so no float formatting is involved.
        x.view(np.int16).ravel().tofile(path, sep=",")
    return a, b, a_path, b_path, c_path
|
|
|
|
|
|
def test_compile_link_matmul():
    """End-to-end check of the AOT compile -> link -> C launcher pipeline.

    Compiles the matmul kernel for 16x16x16 blocks with the stride hints
    ``""`` and ``":16"``, links the results, builds ``libkernel.so`` and the
    C test binary, runs it on random 16x16 inputs and compares the output
    with the element-wise square of the float32 NumPy matmul.
    """
    np.random.seed(3)
    M = N = K = 16
    block_m = block_n = block_k = 16

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Emit, compile and link the Triton kernels.
        kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
        compile_aot_kernels(tmp_dir, kernel_path, "fp16", block_m, block_n, block_k,
                            ha_hb_hints=["", ":16"])
        link_aot_kernels(tmp_dir)

        # Build the shared library and the C driver that calls into it.
        gen_kernel_library(tmp_dir, "libkernel.so")
        gen_test_bin(tmp_dir, M, N, K)

        a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)

        # The driver loads libkernel.so from the temporary directory.
        run_env = dict(os.environ, LD_LIBRARY_PATH=tmp_dir)
        subprocess.run(["./test", a_path, b_path, c_path], env=run_env, check=True, cwd=tmp_dir)

        # The CSV holds int32 bit patterns; reinterpret them as float32.
        raw_bits = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
        c_tri = raw_bits.reshape((M, N)).view(np.float32)
        c_ref = a.astype(np.float32) @ b.astype(np.float32)
        expected = c_ref * c_ref
        np.testing.assert_allclose(c_tri, expected, atol=1e-4, rtol=0.)
|
|
|
|
|
|
def test_launcher_has_no_available_kernel():
    """Check that the launcher fails when no compiled variant matches.

    Only the ``":1"`` hint is compiled, i.e. the launcher requires the hinted
    strides to be 1, while the generated driver passes strides that are not 1.
    The driver is therefore expected to print ``kernel launch failed`` and to
    die on its ``assert`` (return code -6).
    """
    np.random.seed(3)
    M = N = K = 16
    block_m = block_n = block_k = 16

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Emit, compile and link the Triton kernels with the restrictive hint.
        kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
        compile_aot_kernels(tmp_dir, kernel_path, "fp16", block_m, block_n, block_k,
                            ha_hb_hints=[":1"])
        link_aot_kernels(tmp_dir)

        # Build the shared library and the C driver that calls into it.
        gen_kernel_library(tmp_dir, "libkernel.so")
        gen_test_bin(tmp_dir, M, N, K)

        _a, _b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)

        # No check=True here: the process is expected to abort.
        run_env = dict(os.environ, LD_LIBRARY_PATH=tmp_dir)
        result = subprocess.run(["./test", a_path, b_path, c_path], env=run_env, cwd=tmp_dir,
                                capture_output=True, text=True)

        # It should fail since the launcher requires all the strides be 1 while they are not.
        assert result.returncode == -6
        assert "kernel launch failed" in result.stderr
|
|
|
|
|
|
def test_compile_link_autotune_matmul():
    """Compile several tile configurations, link them and run each algo id.

    Four block-size configurations are compiled (each with the stride hints
    ``""`` and ``":16"``) and linked into one library. For every index into
    the configuration list a dedicated driver ``test_<algo_id>`` is built and
    run on the same random 64x64 inputs; its output must match the
    element-wise square of the float32 NumPy matmul.
    """
    np.random.seed(3)
    M = N = K = 64
    tile_sizes = (
        (16, 16, 16),
        (32, 32, 16),
        (32, 32, 32),
        (64, 64, 32),
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)

        for block_m, block_n, block_k in tile_sizes:
            compile_aot_kernels(tmp_dir, kernel_path, "fp16", block_m, block_n, block_k,
                                ha_hb_hints=["", ":16"])

        link_aot_kernels(tmp_dir)
        gen_kernel_library(tmp_dir, "libkernel.so")

        # Inputs and reference are shared by all algo ids.
        a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
        c_ref = a.astype(np.float32) @ b.astype(np.float32)
        expected = c_ref * c_ref

        # The drivers load libkernel.so from the temporary directory.
        run_env = dict(os.environ, LD_LIBRARY_PATH=tmp_dir)

        for algo_id, _tile in enumerate(tile_sizes):
            test_name = f"test_{algo_id}"
            gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id)
            subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir,
                           env=run_env)

            # The CSV holds int32 bit patterns; reinterpret them as float32.
            raw_bits = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
            c_tri = raw_bits.reshape((M, N)).view(np.float32)
            np.testing.assert_allclose(c_tri, expected, atol=1e-4, rtol=1e-4)
|
|
|
|
|
|
def test_ttgir_to_ptx():
    """Compile a minimal TTGIR module from a file and inspect the PTX header.

    The module contains a single public function that takes two
    ``!tt.ptr<i32, 1>`` arguments and returns immediately. It is compiled
    with ``cc=80`` and the resulting PTX must name the ``sm_80`` target and a
    64-bit address size.
    """
    ttgir = """
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} {
  tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32, 1>, %arg1: !tt.ptr<i32, 1>) {
    tt.return
  }
}
"""
    with tempfile.TemporaryDirectory() as tmp_dir:
        kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
        with open(kernel_path, "w") as out:
            out.write(ttgir)

        compiled = triton.compile(kernel_path, cc=80)
        ptx = compiled.asm["ptx"]
        for directive in (".target sm_80", ".address_size 64"):
            assert directive in ptx
|