mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FIX][AOT Compiler] Fix No Specialization Edge Case (#2584)
The [hints dispatching](https://github.com/openai/triton/blob/218492cd65/python/triton/tools/link.py#L161) logic currently fails in the edge case where a single kernel with no specializations is to be linked by the [AOT compiler](https://github.com/openai/triton/blob/main/python/triton/tools/link.py). Because the dispatcher inserts a conditional branch for each specialization case, an empty `if ()` is emitted into the generated C source, which is invalid C and breaks downstream artifacts.

Fix:
- Added a simple check for this edge case.
- Added a unit test that mirrors the existing [`test_compile_link_matmul`](https://github.com/openai/triton/blob/218492cd65/python/test/unit/tools/test_aot.py#L224) test case, save for the aforementioned condition.
This commit is contained in:
@@ -31,7 +31,9 @@ class HeaderParser:
|
||||
import re
|
||||
|
||||
# [kernel_name, c signature]
|
||||
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
|
||||
self.linker_directives = re.compile(
|
||||
"//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)"
|
||||
)
|
||||
# [name, hash, suffix]
|
||||
self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
|
||||
# [(type, name)]
|
||||
@@ -42,7 +44,6 @@ class HeaderParser:
|
||||
self.kernels = defaultdict(list)
|
||||
|
||||
def extract_linker_meta(self, header: str):
|
||||
|
||||
for ln in header.splitlines():
|
||||
if ln.startswith("//"):
|
||||
m = self.linker_directives.match(ln)
|
||||
@@ -76,7 +77,7 @@ class HeaderParser:
|
||||
m = self.c_sig.findall(c_sig)
|
||||
if len(m):
|
||||
tys, args = [], []
|
||||
for (ty, arg_name) in m:
|
||||
for ty, arg_name in m:
|
||||
tys.append(ty)
|
||||
args.append(arg_name)
|
||||
return tys, args
|
||||
@@ -84,7 +85,7 @@ class HeaderParser:
|
||||
raise LinkerError(f"{c_sig} is not a valid argument signature")
|
||||
|
||||
def _match_suffix(self, suffix: str, c_sig: str):
|
||||
args = c_sig.split(',')
|
||||
args = c_sig.split(",")
|
||||
s2i = {"c": 1, "d": 16}
|
||||
num_specs = 0
|
||||
sizes = []
|
||||
@@ -110,7 +111,7 @@ class HeaderParser:
|
||||
if name in self.kernels:
|
||||
last: KernelLinkerMeta = self.kernels[name][-1]
|
||||
|
||||
for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
if cur != new_:
|
||||
raise LinkerError(
|
||||
f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
|
||||
@@ -152,7 +153,9 @@ void unload_{meta.orig_kernel_name}();
|
||||
# generate dispatcher function for kernels with different meta-parameter and constant values
|
||||
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
|
||||
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
|
||||
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
|
||||
src += (
|
||||
f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
|
||||
)
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -164,12 +167,28 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
|
||||
src += "\n"
|
||||
|
||||
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
src += (
|
||||
f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
)
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
|
||||
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
|
||||
src += f" if ({conds})\n"
|
||||
cond_fn = (
|
||||
lambda val, hint: f"({val} % {hint} == 0)"
|
||||
if hint == 16
|
||||
else f"({val} == {hint})"
|
||||
if hint == 1
|
||||
else None
|
||||
)
|
||||
conds = " && ".join(
|
||||
[
|
||||
cond_fn(val, hint)
|
||||
for val, hint in zip(meta.arg_names, meta.sizes)
|
||||
if hint is not None
|
||||
]
|
||||
)
|
||||
src += (
|
||||
f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
|
||||
) # Edge case where no specializations hence no dispatching required
|
||||
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
|
||||
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
|
||||
src += "\n"
|
||||
@@ -183,7 +202,9 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"void {mode}_{name}() {{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += (
|
||||
f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
)
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -252,7 +273,12 @@ if __name__ == "__main__":
|
||||
help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
|
||||
)
|
||||
parser.add_argument("--out", "-o", type=Path, help="Out filename")
|
||||
parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names")
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="",
|
||||
help="String to prefix kernel dispatcher names",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# metadata
|
||||
@@ -280,7 +306,10 @@ if __name__ == "__main__":
|
||||
fp.write(out)
|
||||
|
||||
# generate source
|
||||
defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()]
|
||||
defs = [
|
||||
make_kernel_hints_dispatcher(name, meta)
|
||||
for name, meta in parser.kernels.items()
|
||||
]
|
||||
names = [name for name in parser.kernels.keys()]
|
||||
func_pointers_def = make_func_pointers(names, meta)
|
||||
meta_const_def = make_kernel_meta_const_dispatcher(meta)
|
||||
|
||||
Reference in New Issue
Block a user