[FIX][AOT Compiler] Fix No Specialization Edge Case (#2584)

The [hints
dispatching](218492cd65/python/triton/tools/link.py (L161))
logic currently fails for the edge case where a single kernel with no
specializations is to be linked in the [AOT
compiler](https://github.com/openai/triton/blob/main/python/triton/tools/link.py).

Since the dispatcher inserts a conditional branch for each
specialization case, this results in an `if ()` being inserted into the
`C` source, which clearly breaks downstream artifacts.

Fix:
- Added simple check for this edge case
- Added unit test that mirrors the existing
[`test_compile_link_matmul`](218492cd65/python/test/unit/tools/test_aot.py (L224))
test case save for the aforementioned condition.
This commit is contained in:
jeromeku
2023-11-02 10:02:41 -07:00
committed by GitHub
parent ca8f110617
commit 37cd3d5339
2 changed files with 194 additions and 45 deletions

View File

@@ -31,7 +31,9 @@ class HeaderParser:
import re
# [kernel_name, c signature]
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
self.linker_directives = re.compile(
"//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)"
)
# [name, hash, suffix]
self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
# [(type, name)]
@@ -42,7 +44,6 @@ class HeaderParser:
self.kernels = defaultdict(list)
def extract_linker_meta(self, header: str):
for ln in header.splitlines():
if ln.startswith("//"):
m = self.linker_directives.match(ln)
@@ -76,7 +77,7 @@ class HeaderParser:
m = self.c_sig.findall(c_sig)
if len(m):
tys, args = [], []
for (ty, arg_name) in m:
for ty, arg_name in m:
tys.append(ty)
args.append(arg_name)
return tys, args
@@ -84,7 +85,7 @@ class HeaderParser:
raise LinkerError(f"{c_sig} is not a valid argument signature")
def _match_suffix(self, suffix: str, c_sig: str):
args = c_sig.split(',')
args = c_sig.split(",")
s2i = {"c": 1, "d": 16}
num_specs = 0
sizes = []
@@ -110,7 +111,7 @@ class HeaderParser:
if name in self.kernels:
last: KernelLinkerMeta = self.kernels[name][-1]
for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes):
for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
if cur != new_:
raise LinkerError(
f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
@@ -152,7 +153,9 @@ void unload_{meta.orig_kernel_name}();
# generate dispatcher function for kernels with different meta-parameter and constant values
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
src += (
f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
)
src += "}\n"
return src
@@ -164,12 +167,28 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
src += "\n"
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
src += (
f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
)
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
src += f" if ({conds})\n"
cond_fn = (
lambda val, hint: f"({val} % {hint} == 0)"
if hint == 16
else f"({val} == {hint})"
if hint == 1
else None
)
conds = " && ".join(
[
cond_fn(val, hint)
for val, hint in zip(meta.arg_names, meta.sizes)
if hint is not None
]
)
src += (
f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
) # Edge case where no specializations hence no dispatching required
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
src += "\n"
@@ -183,7 +202,9 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"void {mode}_{name}() {{"
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
src += (
f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
)
src += "}\n"
return src
@@ -252,7 +273,12 @@ if __name__ == "__main__":
help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
)
parser.add_argument("--out", "-o", type=Path, help="Out filename")
parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names")
parser.add_argument(
"--prefix",
type=str,
default="",
help="String to prefix kernel dispatcher names",
)
args = parser.parse_args()
# metadata
@@ -280,7 +306,10 @@ if __name__ == "__main__":
fp.write(out)
# generate source
defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()]
defs = [
make_kernel_hints_dispatcher(name, meta)
for name, meta in parser.kernels.items()
]
names = [name for name in parser.kernels.keys()]
func_pointers_def = make_func_pointers(names, meta)
meta_const_def = make_kernel_meta_const_dispatcher(meta)