[FIX][AOT Compiler] Fix No Specialization Edge Case (#2584)

The hints-dispatching logic (`python/triton/tools/link.py`, line 161 at
commit 218492cd65) currently fails for the edge case where a single kernel
with no specializations is linked by the [AOT
compiler](https://github.com/openai/triton/blob/main/python/triton/tools/link.py).

Since the dispatcher inserts a conditional branch for each
specialization case, a kernel with no specializations results in an empty
`if ()` being inserted into the generated `C` source. That is not valid C,
so every downstream artifact built from that source breaks.

Fix:
- Added simple check for this edge case
- Added a unit test that mirrors the existing `test_compile_link_matmul`
test case (`python/test/unit/tools/test_aot.py`, line 224 at commit
218492cd65), except that the kernel is compiled without any specializations.
This commit is contained in:
jeromeku
2023-11-02 10:02:41 -07:00
committed by GitHub
parent ca8f110617
commit 37cd3d5339
2 changed files with 194 additions and 45 deletions

View File

@@ -59,7 +59,7 @@ def kernel(C, A, B, M, N, K,
tl.store(c_ptrs, c)
"""
test_utils_src = '''
test_utils_src = """
#include <cuda.h>
#include <stdio.h>
#include <stdint.h>
@@ -93,23 +93,26 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
index++;
}
fclose(file);
}'''
}"""
def gen_kernel_library(dir, libname):
    """Compile every C source in ``dir`` and link the objects into a shared library.

    Args:
        dir: Directory that is globbed for ``*.c`` / ``*.o`` files and used as
            the working directory of both ``gcc`` invocations.
        libname: Output file name handed to ``gcc -o`` (callers in this file
            pass ``"libkernel.so"``).

    Raises:
        subprocess.CalledProcessError: If either ``gcc`` invocation exits with
            a non-zero status (``check=True``).
    """
    # Step 1: compile each C file to an object file (-c), built with -fPIC so
    # the objects can go into a shared library.
    c_files = glob.glob(os.path.join(dir, "*.c"))
    subprocess.run(
        ["gcc"] + c_files + ["-I", cuda_include_dir(), "-c", "-fPIC"],
        check=True,
        cwd=dir,
    )

    # Step 2: link every object file found in the directory into `libname`.
    # The glob runs after step 1 so it sees the freshly produced objects.
    o_files = glob.glob(os.path.join(dir, "*.o"))
    subprocess.run(
        ["gcc"] + o_files + ["-shared", "-o", libname, "-L", libcuda_dirs()[0]],
        check=True,
        cwd=dir,
    )
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
test_src = f'''
test_src = f"""
int main(int argc, char **argv) {{
int M = {M}, N = {N}, K = {K};
@@ -165,17 +168,30 @@ int main(int argc, char **argv) {{
cuMemFree(C);
cuCtxDestroy(ctx);
}}
'''
"""
src = test_utils_src + test_src
with open(os.path.join(dir, "test.c"), "w") as file:
file.write(src)
subprocess.run(["gcc"] + ["test.c",
"-I", cuda_include_dir(),
"-L", libcuda_dirs()[0],
"-l", "cuda",
"-L", dir,
"-l", "kernel",
"-o", exe], check=True, cwd=dir)
subprocess.run(
["gcc"]
+ [
"test.c",
"-I",
cuda_include_dir(),
"-L",
libcuda_dirs()[0],
"-l",
"cuda",
"-L",
dir,
"-l",
"kernel",
"-o",
exe,
],
check=True,
cwd=dir,
)
def write_triton_kernels(dir, src, util_src):
@@ -190,16 +206,69 @@ def write_triton_kernels(dir, src, util_src):
return kernel_path
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
def _compile_kernel(
    dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path
):
    """Invoke triton's ``compile.py`` tool in a subprocess for one kernel.

    Args:
        dir: Working directory for the subprocess.
        signature: Value passed to ``--signature``.
        kernel_name: Value passed to ``-n``.
        out_name: Value passed to ``--out-name``.
        out_path: Value passed to ``-o``.
        num_warps: Value passed to ``-w``; converted with ``str()`` first.
        grid: Value passed to ``-g``.
        kernel_path: Positional argument naming the kernel source file.

    Raises:
        subprocess.CalledProcessError: If the compiler script exits with a
            non-zero status (``check=True``).
    """
    compile_script = os.path.join(triton.tools.__path__[0], "compile.py")

    # (flag, value) pairs in the exact order they appear on the command line.
    flag_values = (
        ("-n", kernel_name),
        ("--signature", signature),
        ("--out-name", out_name),
        ("-o", out_path),
        ("-w", str(num_warps)),
        ("-g", grid),
    )

    # Run the script with the current interpreter.
    command = [sys.executable, compile_script]
    for pair in flag_values:
        command.extend(pair)
    command.append(kernel_path)

    subprocess.run(command, check=True, cwd=dir)
# Edge case kernel with no specialization
def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
    """Compile a single matmul kernel whose signature carries no hint suffixes.

    Unlike ``compile_aot_kernels``, exactly one kernel is compiled and none of
    the signature entries has a ``:<hint>`` suffix.

    Args:
        dir: Working directory forwarded to ``_compile_kernel``.
        kernel_path: Path of the kernel source file to compile.
        dtype: Element type name used for the two input pointers and in the
            output name (``matmul_<dtype>``).
        BM, BN, BK: Tile sizes appended to the signature; ``BM`` and ``BN``
            also appear in the grid expression.
    """
    pointer_args = ["*fp32", f"*{dtype}", f"*{dtype}"]
    scalar_args = ["i32"] * 9
    tile_args = [f"{BM}", f"{BN}", f"{BK}"]
    plain_signature = ", ".join(pointer_args + scalar_args + tile_args)

    artifact = f"matmul_{dtype}"
    launch_grid = f"M/{BM}, N/{BN}, 1"

    # The same string is used as both the output name and the output path.
    _compile_kernel(
        dir=dir,
        signature=plain_signature,
        kernel_name="kernel",
        out_name=artifact,
        out_path=artifact,
        num_warps=1,
        grid=launch_grid,
        kernel_path=kernel_path,
    )
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
    """Compile one matmul kernel variant per ordered pair of hints.

    For every ``(ha, hb)`` combination drawn from ``ha_hb_hints`` a signature
    is built in which ``ha`` is appended to the 7th entry and ``hb`` to the
    9th, and the kernel is compiled through ``_compile_kernel``. All variants
    share the same output name, ``matmul_<dtype>``.

    Args:
        dir: Working directory forwarded to ``_compile_kernel``.
        kernel_path: Path of the kernel source file to compile.
        dtype: Element type name used for the two input pointers and in the
            output name.
        BM, BN, BK: Tile sizes appended to the signature; ``BM`` and ``BN``
            also appear in the grid expression.
        ha_hb_hints: Iterable of hint suffixes (callers in this file pass
            ``["", ":16"]``). An empty iterable compiles nothing.
    """
    # Neither value depends on the hints, so build them once.
    name = f"matmul_{dtype}"
    grid = f"M/{BM}, N/{BN}, 1"

    # compile all desired configs
    for ha in ha_hb_hints:
        for hb in ha_hb_hints:
            sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}"
            _compile_kernel(
                dir=dir,
                signature=sig,
                kernel_name="kernel",
                out_name=name,
                out_path=name,
                num_warps=1,
                grid=grid,
                kernel_path=kernel_path,
            )
def link_aot_kernels(dir):
@@ -207,7 +276,9 @@ def link_aot_kernels(dir):
# link all desired configs
h_files = glob.glob(os.path.join(dir, "*.h"))
subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir)
subprocess.run(
[sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir
)
def generate_matmul_test_data(dir, M, N, K):
@@ -221,16 +292,16 @@ def generate_matmul_test_data(dir, M, N, K):
return a, b, a_path, b_path, c_path
def test_compile_link_matmul():
# Test edge case where the provided kernel signature has no specializations
def test_compile_link_matmul_no_specialization():
np.random.seed(3)
with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
BM, BN, BK = 16, 16, 16
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
link_aot_kernels(tmp_dir)
# compile test case
@@ -244,13 +315,50 @@ def test_compile_link_matmul():
# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
subprocess.run(
["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir
)
# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
c_tri = c.reshape((M, N)).view(np.float32)
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.)
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)
def test_compile_link_matmul():
    """End-to-end AOT test: compile hinted matmul kernels, link, run, compare.

    Compiles the kernel for every combination of the ``""`` and ``":16"``
    hints, links the results, builds a shared library and a C test binary,
    runs the binary on generated inputs and checks its output against a NumPy
    reference.
    """
    np.random.seed(3)

    dtype = "fp16"
    BM, BN, BK = 16, 16, 16
    M, N, K = 16, 16, 16

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Emit the kernel sources, compile each hinted variant, then link them.
        kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
        compile_aot_kernels(
            tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]
        )
        link_aot_kernels(tmp_dir)

        # Build the shared library and the test executable.
        gen_kernel_library(tmp_dir, "libkernel.so")
        gen_test_bin(tmp_dir, M, N, K)

        # Create the input matrices and the file paths used by the binary.
        a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)

        # Put the temp dir on LD_LIBRARY_PATH for the test binary.
        run_env = dict(os.environ, LD_LIBRARY_PATH=tmp_dir)
        subprocess.run(
            ["./test", a_path, b_path, c_path],
            env=run_env,
            check=True,
            cwd=tmp_dir,
        )

        # The CSV is parsed as int32 and the raw bits are then reinterpreted
        # as float32.
        raw_bits = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
        c_tri = raw_bits.reshape((M, N)).view(np.float32)

        a_f32 = a.astype(np.float32)
        b_f32 = b.astype(np.float32)
        c_ref = np.matmul(a_f32, b_f32)
        # NOTE(review): the expected value is the element-wise square of A @ B;
        # presumably the kernel squares its result via kernel_utils_src. Confirm.
        np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)
def test_launcher_has_no_available_kernel():
@@ -275,7 +383,13 @@ def test_launcher_has_no_available_kernel():
# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True)
result = subprocess.run(
["./test", a_path, b_path, c_path],
env=env,
cwd=tmp_dir,
capture_output=True,
text=True,
)
# It should fail since the launcher requires all the strides be 1 while they are not.
assert result.returncode == -6
@@ -286,7 +400,6 @@ def test_compile_link_autotune_matmul():
np.random.seed(3)
with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
@@ -300,7 +413,9 @@ def test_compile_link_autotune_matmul():
for ts in tile_sizes:
BM, BN, BK = ts[0], ts[1], ts[2]
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
compile_aot_kernels(
tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]
)
link_aot_kernels(tmp_dir)
@@ -319,7 +434,12 @@ def test_compile_link_autotune_matmul():
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env)
subprocess.run(
[f"./{test_name}", a_path, b_path, c_path],
check=True,
cwd=tmp_dir,
env=env,
)
# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)