mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TOOLS][AOT] some issues in equal_to_1 hint (#1998)
- Change test_aot.py to actually use the equal_to_1 hint.
- In the client function, equal_to_1 parameters are not specialized: AOT clients may not know the details of Triton argument specialization, and they still want to use the same parameter list they wrote for the Triton kernel. The generated kernels have a specialized argument list, and the generated dispatcher code makes sure the correct arguments from the original full argument list are passed.
- Fixed a bug in _match_suffix in link.py. Previously it assumed each parameter has a suffix of either ‘d’ or ‘c’, but in fact a parameter sometimes has no suffix, as in 0d1d2d34c56c78c (here parameters 3, 5 and 7 carry no suffix).
This commit is contained in:
@@ -107,7 +107,7 @@ int main(int argc, char **argv) {
|
||||
// launch kernel
|
||||
int gX = 1, gY = 1, gZ = 1;
|
||||
cuStreamSynchronize(stream);
|
||||
matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, K, N);
|
||||
matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1);
|
||||
cuStreamSynchronize(stream);
|
||||
|
||||
// read data
|
||||
@@ -150,7 +150,7 @@ def test_compile_link_matmul():
|
||||
hints = [":16", ""]
|
||||
for ha in hints:
|
||||
for hb in hints:
|
||||
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, 1, i32{hb}, 1, i32:16, 1, {BM}, {BN}, {BK}'
|
||||
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
|
||||
name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}"
|
||||
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=tmp_dir)
|
||||
|
||||
|
||||
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
Executable file
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
Executable file
Binary file not shown.
@@ -10,7 +10,7 @@
|
||||
|
||||
void unload_{kernel_name}(void);
|
||||
void load_{kernel_name}(void);
|
||||
// tt-linker: {kernel_name}:{signature}
|
||||
// tt-linker: {kernel_name}:{full_signature}
|
||||
CUresult{_placeholder} {kernel_name}(CUstream stream, unsigned int gX,
|
||||
unsigned int gY, unsigned int gZ,
|
||||
{signature});
|
||||
|
||||
@@ -94,8 +94,15 @@ if __name__ == "__main__":
|
||||
divisible_by_16 = [i for i, h in hints.items() if h == 16]
|
||||
equal_to_1 = [i for i, h in hints.items() if h == 1]
|
||||
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
|
||||
for i in equal_to_1:
|
||||
constexprs.update({i: 1})
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps)
|
||||
arg_names = [kernel.arg_names[i] for i in signature.keys()]
|
||||
arg_names = []
|
||||
arg_types = []
|
||||
for i in signature.keys():
|
||||
if i not in equal_to_1:
|
||||
arg_names += [kernel.arg_names[i]]
|
||||
arg_types += [signature[i]]
|
||||
|
||||
# dump C stub code
|
||||
suffix = kernel_suffix(signature.values(), config)
|
||||
@@ -107,7 +114,8 @@ if __name__ == "__main__":
|
||||
"triton_kernel_name": triton_kernel_name,
|
||||
"bin_size": len(hex_),
|
||||
"bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
|
||||
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, signature.values())]),
|
||||
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
|
||||
"full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]),
|
||||
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]),
|
||||
"num_args": len(arg_names),
|
||||
"kernel_docstring": "",
|
||||
|
||||
@@ -32,10 +32,10 @@ class HeaderParser:
|
||||
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+)")
|
||||
# [name, suffix]
|
||||
self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
|
||||
# [(argnum, d|c)]
|
||||
self.kernel_suffix = re.compile("([0-9]+)([c,d])")
|
||||
# [(type, name)]
|
||||
self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?")
|
||||
# [d|c]
|
||||
self.arg_suffix = re.compile("[c,d]")
|
||||
|
||||
self.kernels = defaultdict(list)
|
||||
|
||||
@@ -48,7 +48,7 @@ class HeaderParser:
|
||||
ker_name, c_sig = m.group(1), m.group(2)
|
||||
name, sig_hash, suffix = self._match_name(ker_name)
|
||||
c_types, arg_names = self._match_c_sig(c_sig)
|
||||
num_specs, sizes = self._match_suffix(suffix)
|
||||
num_specs, sizes = self._match_suffix(suffix, c_sig)
|
||||
self._add_kernel(
|
||||
name,
|
||||
KernelLinkerMeta(
|
||||
@@ -79,18 +79,27 @@ class HeaderParser:
|
||||
|
||||
raise LinkerError(f"{c_sig} is not a valid argument signature")
|
||||
|
||||
def _match_suffix(self, suffix: str):
|
||||
m = self.kernel_suffix.findall(suffix)
|
||||
if not len(m):
|
||||
raise LinkerError(f"{suffix} is not a valid kernel suffix")
|
||||
sizes = []
|
||||
num_specs = len(m)
|
||||
def _match_suffix(self, suffix: str, c_sig: str):
|
||||
args = c_sig.split(',')
|
||||
s2i = {"c": 1, "d": 16}
|
||||
for (argnum, arg_size_ann) in m:
|
||||
while len(sizes) < int(argnum):
|
||||
sizes.append(None)
|
||||
|
||||
sizes.append(s2i[arg_size_ann])
|
||||
num_specs = 0
|
||||
sizes = []
|
||||
# scan through suffix, first find the index,
|
||||
# then see if it is followed by d or c
|
||||
for i in range(len(args)):
|
||||
pos = suffix.find(str(i))
|
||||
if pos == -1:
|
||||
raise LinkerError(f"{suffix} is not a valid kernel suffix")
|
||||
pos += len(str(i))
|
||||
if self.arg_suffix.match(suffix, pos):
|
||||
num_specs += 1
|
||||
sizes.extend([None] * (i - len(sizes)))
|
||||
sizes.append(s2i[suffix[pos]])
|
||||
pos += 1
|
||||
if i < len(args) - 1:
|
||||
suffix = suffix[pos:]
|
||||
else:
|
||||
sizes.extend([None] * (len(args) - len(sizes)))
|
||||
return num_specs, sizes
|
||||
|
||||
def _add_kernel(self, name: str, ker: KernelLinkerMeta):
|
||||
@@ -106,13 +115,20 @@ class HeaderParser:
|
||||
self.kernels[name].append(ker)
|
||||
|
||||
|
||||
def gen_signature(m):
|
||||
def gen_signature_with_full_args(m):
|
||||
return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)])
|
||||
|
||||
|
||||
def gen_signature(m):
|
||||
arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1]
|
||||
arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1]
|
||||
sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)])
|
||||
return sig
|
||||
|
||||
|
||||
def make_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
|
||||
return f"""
|
||||
CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(metas[-1])});
|
||||
CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])});
|
||||
void load_{name}();
|
||||
void unload_{name}();
|
||||
"""
|
||||
@@ -124,13 +140,14 @@ def make_kernel_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
|
||||
src += f"CUresult {name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(meta)});\n"
|
||||
src += "\n"
|
||||
|
||||
src += f"CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(metas[-1])}){{"
|
||||
src += f"CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
|
||||
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
|
||||
src += f" if ({conds})\n"
|
||||
src += f" return {name}_{meta.sig_hash}_{meta.suffix}(stream, gX, gY, gZ, {', '.join(meta.arg_names)});\n"
|
||||
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
|
||||
src += f" return {name}_{meta.sig_hash}_{meta.suffix}(stream, gX, gY, gZ, {', '.join(arg_names)});\n"
|
||||
src += "}\n"
|
||||
|
||||
for mode in ["load", "unload"]:
|
||||
|
||||
Reference in New Issue
Block a user