[FRONTEND] fix the silent return issue in AOT launcher (#2013)
In the current `link.py`, the produced launcher code looks like this:

```c
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_am, int32_t stride_bk){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm % 16 == 0))
    return matmul_fp16xfp16_16x16x16_688cc413_0d1d2d3d45d(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
  // ...
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0))
    return matmul_fp16xfp16_16x16x16_7c0255bf_0d1d2d345(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
}
```
Note that when the arguments do not match any of the `if` branches, the function does nothing: control falls off the end without a `return`, and the compiled code typically yields 0, which equals `CUDA_SUCCESS`. A launch that never happened is thus silently reported as successful, which does not match the expectation.
This PR adds a `return CUDA_ERROR_INVALID_VALUE;` at the end of each launcher, so it produces code like:

```c
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_cn, int32_t stride_am, int32_t stride_ak, int32_t stride_bk, int32_t stride_bn){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm == 1) && (stride_cn == 1) && (stride_am == 1) && (stride_ak == 1) && (stride_bk % 16 == 0) && (stride_bn == 1))
    return matmul_fp16xfp16_16x16x16_1f18a6da_0d1d2d3c4c5c6c7d8c(stream, gX, gY, gZ, C, A, B, stride_bk);
  return CUDA_ERROR_INVALID_VALUE;
}
```
This requires users to check the returned result in their application, which I think matches the initial AOT design intent.
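
For illustration, a minimal caller-side sketch of the check this enables, assuming the generated `kernel.h` and the 13-argument launcher signature shown above (`launch_checked` is an illustrative wrapper, not part of the PR):

```c
#include <cuda.h>
#include <stdio.h>
#include "kernel.h"  /* generated by link.py */

/* Launch the AOT kernel and surface a dispatch miss instead of
   silently continuing; strides mirror the test program below. */
static int launch_checked(CUstream stream, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B,
                          int M, int N, int K, int BM, int BN) {
    CUresult ret = matmul_fp16xfp16_16x16x16(stream, M / BM, N / BN, 1,
                                             C, A, B, N, 1, K, 1, N, 1);
    if (ret != CUDA_SUCCESS) {
        const char *name = NULL;
        cuGetErrorName(ret, &name);  /* e.g. "CUDA_ERROR_INVALID_VALUE" on a miss */
        fprintf(stderr, "kernel launch failed: %s\n", name ? name : "unknown");
        return 1;
    }
    return 0;
}
```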
The accompanying test changes:

```diff
@@ -40,17 +40,20 @@ def kernel(C, A, B,
     tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c)
 """
 
-test_src = """
+
+def gen_test_bin(dir, M, N, K, BM, BN, BK):
+    test_src = '''
 #include <cuda.h>
 #include <stdio.h>
 #include <stdint.h>
 #include <string.h>
 #include <assert.h>
 #include "kernel.h"
 
 static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) {
     FILE *file = fopen(filename, "w");
     if (file == NULL) {
-        printf(\"Could not open file %s\\n\", filename);
+        printf("Could not open file %s\\n", filename);
         return;
     }
     for (int i = 0; i < size; i++) {
@@ -65,7 +68,7 @@ static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) {
 static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
     FILE *file = fopen(filename, "r");
     if (file == NULL) {
-        printf(\"Could not open file %s\\n\", filename);
+        printf("Could not open file %s\\n", filename);
         return;
     }
     int index = 0;
@@ -73,11 +76,12 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
         index++;
     }
     fclose(file);
-}
+}'''
 
-int main(int argc, char **argv) {
-    int M = 16, N = 16, K = 16;
-    int BM = 16, BN = 16, BK = 16;
+    test_src += f'''
+int main(int argc, char **argv) {{
+    int M = {M}, N = {N}, K = {K};
+    int BM = {BM}, BN = {BN}, BK = {BK};
 
     // initialize CUDA handles
     CUdevice dev;
```
```diff
@@ -105,9 +109,11 @@ int main(int argc, char **argv) {
     cuMemcpyHtoD(B, hB, K*N*2);
 
     // launch kernel
     int gX = 1, gY = 1, gZ = 1;
     cuStreamSynchronize(stream);
-    matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1);
+    CUresult ret = matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1);
+    if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
+    assert(ret == 0);
     cuStreamSynchronize(stream);
 
     // read data
```
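
The capture-report-abort pattern the test adopts here is often wrapped in a checking macro in larger CUDA applications; a minimal sketch (`CU_CHECK` is a hypothetical name, not part of this PR):

```c
#include <cuda.h>
#include <stdio.h>
#include <stdlib.h>

/* Hypothetical helper: evaluate a driver-API call (or an AOT launcher,
   which also returns CUresult) and abort on any failure. */
#define CU_CHECK(expr)                                                    \
    do {                                                                  \
        CUresult _r = (expr);                                             \
        if (_r != CUDA_SUCCESS) {                                         \
            const char *_name = NULL;                                     \
            cuGetErrorName(_r, &_name);                                   \
            fprintf(stderr, "%s:%d: %s failed: %s\n", __FILE__, __LINE__, \
                    #expr, _name ? _name : "unknown");                    \
            abort();                                                      \
        }                                                                 \
    } while (0)

/* Usage, mirroring the test above:
   CU_CHECK(matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1)); */
```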
```diff
@@ -116,65 +122,75 @@ int main(int argc, char **argv) {
     cuMemcpyDtoH(hC, C, M*N*4);
     write_buffer_to_csv(argv[3], hC, M*N);
 
     // free cuda handles
     unload_matmul_fp16xfp16_16x16x16();
     cuMemFree(A);
     cuMemFree(B);
     cuMemFree(C);
     cuCtxDestroy(ctx);
-}
-"""
+}}
+'''
+
+    with open(os.path.join(dir, "test.c"), "w") as file:
+        file.write(test_src)
+    c_files = glob.glob(os.path.join(dir, "*.c"))
+    subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
+                   "-L", libcuda_dirs()[0],
+                   "-l", "cuda",
+                   "-o", "test"], check=True, cwd=dir)
+
+
+def generate_matmul_launcher(dir, dtype, BM, BN, BK, ha_hb_hints):
+    kernel_path = os.path.join(dir, "kernel.py")
+    with open(kernel_path, "w") as file:
+        file.write(kernel_src)
+
+    kernel_utils_path = os.path.join(dir, "kernel_utils.py")
+    with open(kernel_utils_path, "w") as file:
+        file.write(kernel_utils_src)
+
+    compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
+    linker_path = os.path.join(triton.tools.__path__[0], "link.py")
+
+    # compile all desired configs
+    for ha in ha_hb_hints:
+        for hb in ha_hb_hints:
+            sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
+            name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}"
+            subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=dir)
+
+    # link all desired configs
+    h_files = glob.glob(os.path.join(dir, "*.h"))
+    subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir)
+
+
+def generate_matmul_test_data(dir, M, N, K):
+    a = np.random.randn(M * K).astype(np.float16).reshape((M, K))
+    b = np.random.randn(M * K).astype(np.float16).reshape((K, N))
+    a_path = os.path.join(dir, "a.csv")
+    b_path = os.path.join(dir, "b.csv")
+    c_path = os.path.join(dir, "c.csv")
+    for x, path in [(a, a_path), (b, b_path)]:
+        x.view(np.int16).ravel().tofile(path, sep=",")
+    return a, b, a_path, b_path, c_path
+
+
 def test_compile_link_matmul():
     np.random.seed(3)
 
     with tempfile.TemporaryDirectory() as tmp_dir:
-        kernel_path = os.path.join(tmp_dir, "kernel.py")
-        with open(kernel_path, "w") as file:
-            file.write(kernel_src)
-
-        kernel_utils_path = os.path.join(tmp_dir, "kernel_utils.py")
-        with open(kernel_utils_path, "w") as file:
-            file.write(kernel_utils_src)
-
-        compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
-        linker_path = os.path.join(triton.tools.__path__[0], "link.py")
-
         dtype = "fp16"
-        M, N, K = 16, 16, 16
         BM, BN, BK = 16, 16, 16
 
-        # compile all desired configs
-        hints = [":16", ""]
-        for ha in hints:
-            for hb in hints:
-                sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
-                name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}"
-                subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=tmp_dir)
-
-        # link all desired configs
-        h_files = glob.glob(os.path.join(tmp_dir, "*.h"))
-        subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=tmp_dir)
+        generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
 
         # compile test case
-        with open(os.path.join(tmp_dir, "test.c"), "w") as file:
-            file.write(test_src)
-        c_files = glob.glob(os.path.join(tmp_dir, "*.c"))
-        subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
-                       "-L", libcuda_dirs()[0],
-                       "-l", "cuda",
-                       "-o", "test"], check=True, cwd=tmp_dir)
+        M, N, K = 16, 16, 16
+        gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
 
         # initialize test data
-        a = np.random.randn(M * K).astype(np.float16).reshape((M, K))
-        b = np.random.randn(M * K).astype(np.float16).reshape((K, N))
-        a_path = os.path.join(tmp_dir, "a.csv")
-        b_path = os.path.join(tmp_dir, "b.csv")
-        c_path = os.path.join(tmp_dir, "c.csv")
-        for x, path in [(a, a_path), (b, b_path)]:
-            x.view(np.int16).ravel().tofile(path, sep=",")
+        a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
 
         # run test case
         subprocess.run(["./test", a_path, b_path, c_path], check=True, cwd=tmp_dir)
```
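
Note how `generate_matmul_test_data` serializes the fp16 matrices: it reinterprets their raw bits as `int16` (`x.view(np.int16)`) before writing CSV, and `read_csv_to_buffer` reads them back into an `int16_t` buffer on the C side, so the bit patterns survive the round trip untouched. If one wanted to inspect those values as floats in plain C, the usual binary16-to-binary32 widening looks like this (a standalone sketch; `half_bits_to_float` is an illustrative name, not part of the test):

```c
#include <stdint.h>
#include <string.h>

/* Widen one IEEE-754 binary16 bit pattern (as stored in the int16 CSVs)
   to a float. Covers zero, subnormal, normal, and inf/NaN inputs. */
static float half_bits_to_float(uint16_t h) {
    uint32_t sign = (uint32_t)(h >> 15) << 31;
    uint32_t exp  = (h >> 10) & 0x1F;   /* 5-bit exponent field */
    uint32_t mant = h & 0x3FF;          /* 10-bit mantissa field */
    uint32_t bits;
    if (exp == 0) {
        if (mant == 0) {
            bits = sign;  /* signed zero */
        } else {
            /* subnormal half: renormalize for the wider float format */
            exp = 127 - 15 + 1;
            while (!(mant & 0x400)) { mant <<= 1; exp--; }
            bits = sign | (exp << 23) | ((mant & 0x3FF) << 13);
        }
    } else if (exp == 31) {
        bits = sign | 0x7F800000u | (mant << 13);  /* inf or NaN */
    } else {
        bits = sign | ((exp - 15 + 127) << 23) | (mant << 13);  /* normal */
    }
    float f;
    memcpy(&f, &bits, sizeof f);  /* bit-cast without aliasing UB */
    return f;
}
```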
```diff
@@ -186,6 +202,30 @@ def test_compile_link_matmul():
     np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.)
 
 
+def test_launcher_has_no_available_kernel():
+    np.random.seed(3)
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        dtype = "fp16"
+        BM, BN, BK = 16, 16, 16
+
+        generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=[":1"])
+
+        # compile test case
+        M, N, K = 16, 16, 16
+        gen_test_bin(tmp_dir, M, N, K, BM, BN, BK)
+
+        # initialize test data
+        a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
+
+        # run test case
+        result = subprocess.run(["./test", a_path, b_path, c_path], cwd=tmp_dir, capture_output=True, text=True)
+
+        # It should fail, since the launcher requires all the strides to be 1 while they are not.
+        assert result.returncode == -6
+        assert "kernel launch failed" in result.stderr
+
+
 def test_ttgir_to_ptx():
     src = """
 module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
```
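
For context on `assert result.returncode == -6`: with `ha_hb_hints=[":1"]`, every generated specialization guards on stride values of 1 that the actual matrices do not have, so the launcher falls through to `return CUDA_ERROR_INVALID_VALUE`. The C test then fails `assert(ret == 0)`, which calls `abort()` and kills the process with `SIGABRT` (signal 6), and `subprocess` reports death-by-signal as the negated signal number. A standalone sketch of that mechanism (not part of the PR):

```c
#include <assert.h>

/* With NDEBUG undefined, a failing assert prints a diagnostic and calls
   abort(), so the process dies with SIGABRT (signal 6). Python's
   subprocess reports such a death as returncode == -6. */
int main(void) {
    int launch_ok = 0;       /* stand-in for a failed kernel launch */
    assert(launch_ok == 1);  /* aborts here */
    return 0;
}
```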