mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[FIX][AOT Compiler] Fix No Specialization Edge Case (#2584)
The hints dispatching logic (`python/triton/tools/link.py`, line 161 at commit `218492cd65`) currently fails for the edge case where a single kernel with no specializations is linked by the [AOT compiler](https://github.com/openai/triton/blob/main/python/triton/tools/link.py). Because the dispatcher inserts one conditional branch for each specialization case, a kernel with no specializations results in an empty `if ()` being inserted into the generated C source, which is invalid C and breaks downstream artifacts.

Fix:
- Added a simple check for this edge case.
- Added a unit test that mirrors the existing `test_compile_link_matmul` test case (`python/test/unit/tools/test_aot.py`, line 224 at commit `218492cd65`), except that the kernel signature has no specializations.
This commit is contained in:
@@ -59,7 +59,7 @@ def kernel(C, A, B, M, N, K,
|
||||
tl.store(c_ptrs, c)
|
||||
"""
|
||||
|
||||
test_utils_src = '''
|
||||
test_utils_src = """
|
||||
#include <cuda.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
@@ -93,23 +93,26 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
|
||||
index++;
|
||||
}
|
||||
fclose(file);
|
||||
}'''
|
||||
}"""
|
||||
|
||||
|
||||
def gen_kernel_library(dir, libname):
|
||||
c_files = glob.glob(os.path.join(dir, "*.c"))
|
||||
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
|
||||
"-c", "-fPIC"],
|
||||
check=True, cwd=dir)
|
||||
subprocess.run(
|
||||
["gcc"] + c_files + ["-I", cuda_include_dir(), "-c", "-fPIC"],
|
||||
check=True,
|
||||
cwd=dir,
|
||||
)
|
||||
o_files = glob.glob(os.path.join(dir, "*.o"))
|
||||
subprocess.run(["gcc"] + o_files + ["-shared",
|
||||
"-o", libname,
|
||||
"-L", libcuda_dirs()[0]],
|
||||
check=True, cwd=dir)
|
||||
subprocess.run(
|
||||
["gcc"] + o_files + ["-shared", "-o", libname, "-L", libcuda_dirs()[0]],
|
||||
check=True,
|
||||
cwd=dir,
|
||||
)
|
||||
|
||||
|
||||
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
|
||||
test_src = f'''
|
||||
test_src = f"""
|
||||
int main(int argc, char **argv) {{
|
||||
int M = {M}, N = {N}, K = {K};
|
||||
|
||||
@@ -165,17 +168,30 @@ int main(int argc, char **argv) {{
|
||||
cuMemFree(C);
|
||||
cuCtxDestroy(ctx);
|
||||
}}
|
||||
'''
|
||||
"""
|
||||
src = test_utils_src + test_src
|
||||
with open(os.path.join(dir, "test.c"), "w") as file:
|
||||
file.write(src)
|
||||
subprocess.run(["gcc"] + ["test.c",
|
||||
"-I", cuda_include_dir(),
|
||||
"-L", libcuda_dirs()[0],
|
||||
"-l", "cuda",
|
||||
"-L", dir,
|
||||
"-l", "kernel",
|
||||
"-o", exe], check=True, cwd=dir)
|
||||
subprocess.run(
|
||||
["gcc"]
|
||||
+ [
|
||||
"test.c",
|
||||
"-I",
|
||||
cuda_include_dir(),
|
||||
"-L",
|
||||
libcuda_dirs()[0],
|
||||
"-l",
|
||||
"cuda",
|
||||
"-L",
|
||||
dir,
|
||||
"-l",
|
||||
"kernel",
|
||||
"-o",
|
||||
exe,
|
||||
],
|
||||
check=True,
|
||||
cwd=dir,
|
||||
)
|
||||
|
||||
|
||||
def write_triton_kernels(dir, src, util_src):
|
||||
@@ -190,16 +206,69 @@ def write_triton_kernels(dir, src, util_src):
|
||||
return kernel_path
|
||||
|
||||
|
||||
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
|
||||
def _compile_kernel(
|
||||
dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path
|
||||
):
|
||||
compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
compiler_path,
|
||||
"-n",
|
||||
kernel_name,
|
||||
"--signature",
|
||||
signature,
|
||||
"--out-name",
|
||||
out_name,
|
||||
"-o",
|
||||
out_path,
|
||||
"-w",
|
||||
str(num_warps),
|
||||
"-g",
|
||||
grid,
|
||||
kernel_path,
|
||||
],
|
||||
check=True,
|
||||
cwd=dir,
|
||||
)
|
||||
|
||||
|
||||
# Edge case kernel with no specialization
|
||||
def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
|
||||
# compile all desired configs
|
||||
sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}"
|
||||
name = f"matmul_{dtype}"
|
||||
grid = f"M/{BM}, N/{BN}, 1"
|
||||
_compile_kernel(
|
||||
dir=dir,
|
||||
signature=sig,
|
||||
kernel_name="kernel",
|
||||
out_name=name,
|
||||
out_path=name,
|
||||
num_warps=1,
|
||||
grid=grid,
|
||||
kernel_path=kernel_path,
|
||||
)
|
||||
|
||||
|
||||
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
|
||||
# compile all desired configs
|
||||
for ha in ha_hb_hints:
|
||||
for hb in ha_hb_hints:
|
||||
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
|
||||
sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}"
|
||||
name = f"matmul_{dtype}"
|
||||
grid = f'M/{BM}, N/{BN}, 1'
|
||||
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir)
|
||||
grid = f"M/{BM}, N/{BN}, 1"
|
||||
_compile_kernel(
|
||||
dir=dir,
|
||||
signature=sig,
|
||||
kernel_name="kernel",
|
||||
out_name=name,
|
||||
out_path=name,
|
||||
num_warps=1,
|
||||
grid=grid,
|
||||
kernel_path=kernel_path,
|
||||
)
|
||||
|
||||
|
||||
def link_aot_kernels(dir):
|
||||
@@ -207,7 +276,9 @@ def link_aot_kernels(dir):
|
||||
|
||||
# link all desired configs
|
||||
h_files = glob.glob(os.path.join(dir, "*.h"))
|
||||
subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir)
|
||||
subprocess.run(
|
||||
[sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir
|
||||
)
|
||||
|
||||
|
||||
def generate_matmul_test_data(dir, M, N, K):
|
||||
@@ -221,16 +292,16 @@ def generate_matmul_test_data(dir, M, N, K):
|
||||
return a, b, a_path, b_path, c_path
|
||||
|
||||
|
||||
def test_compile_link_matmul():
|
||||
# Test edge case where the provided kernel signature has no specializations
|
||||
def test_compile_link_matmul_no_specialization():
|
||||
np.random.seed(3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
dtype = "fp16"
|
||||
BM, BN, BK = 16, 16, 16
|
||||
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
|
||||
compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
# compile test case
|
||||
@@ -244,13 +315,50 @@ def test_compile_link_matmul():
|
||||
# run test case
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
|
||||
subprocess.run(
|
||||
["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir
|
||||
)
|
||||
|
||||
# read data and compare against reference
|
||||
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
|
||||
c_tri = c.reshape((M, N)).view(np.float32)
|
||||
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
|
||||
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.)
|
||||
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)
|
||||
|
||||
|
||||
def test_compile_link_matmul():
|
||||
np.random.seed(3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
dtype = "fp16"
|
||||
BM, BN, BK = 16, 16, 16
|
||||
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
compile_aot_kernels(
|
||||
tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]
|
||||
)
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
# compile test case
|
||||
M, N, K = 16, 16, 16
|
||||
gen_kernel_library(tmp_dir, "libkernel.so")
|
||||
gen_test_bin(tmp_dir, M, N, K)
|
||||
|
||||
# initialize test data
|
||||
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
|
||||
|
||||
# run test case
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
subprocess.run(
|
||||
["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir
|
||||
)
|
||||
|
||||
# read data and compare against reference
|
||||
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
|
||||
c_tri = c.reshape((M, N)).view(np.float32)
|
||||
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
|
||||
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)
|
||||
|
||||
|
||||
def test_launcher_has_no_available_kernel():
|
||||
@@ -275,7 +383,13 @@ def test_launcher_has_no_available_kernel():
|
||||
# run test case
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True)
|
||||
result = subprocess.run(
|
||||
["./test", a_path, b_path, c_path],
|
||||
env=env,
|
||||
cwd=tmp_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# It should fail since the launcher requires all the strides be 1 while they are not.
|
||||
assert result.returncode == -6
|
||||
@@ -286,7 +400,6 @@ def test_compile_link_autotune_matmul():
|
||||
np.random.seed(3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
dtype = "fp16"
|
||||
|
||||
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
|
||||
@@ -300,7 +413,9 @@ def test_compile_link_autotune_matmul():
|
||||
|
||||
for ts in tile_sizes:
|
||||
BM, BN, BK = ts[0], ts[1], ts[2]
|
||||
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
|
||||
compile_aot_kernels(
|
||||
tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]
|
||||
)
|
||||
|
||||
link_aot_kernels(tmp_dir)
|
||||
|
||||
@@ -319,7 +434,12 @@ def test_compile_link_autotune_matmul():
|
||||
|
||||
env = os.environ.copy()
|
||||
env["LD_LIBRARY_PATH"] = tmp_dir
|
||||
subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env)
|
||||
subprocess.run(
|
||||
[f"./{test_name}", a_path, b_path, c_path],
|
||||
check=True,
|
||||
cwd=tmp_dir,
|
||||
env=env,
|
||||
)
|
||||
|
||||
# read data and compare against reference
|
||||
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
|
||||
|
||||
Reference in New Issue
Block a user