[TOOLS][AOT] some issues in equal_to_1 hint (#1998)

- Change test_aot.py to actually use the equal_to_1 hint

- In the client function, equal_to_1 parameters are not specialized.
AOT clients may not know the details of Triton argument specialization,
and they still want to use the same parameter list they wrote for the
Triton kernel. The generated kernels have a specialized argument list;
the generated dispatcher code makes sure the correct arguments from the
original full argument list are passed.

- Fixed a bug in _match_suffix in link.py. Previously it assumed each
parameter has a suffix of either ‘d’ or ‘c’, but in fact a parameter
sometimes has no suffix, as in 0d1d2d34c56c78c (here parameters 3, 5
and 7 carry no suffix).
This commit is contained in:
Bin Fan
2023-07-27 16:07:49 -07:00
committed by GitHub
parent 4f1b2ea8d7
commit 2689f4a3b0
5 changed files with 48 additions and 23 deletions

View File

@@ -107,7 +107,7 @@ int main(int argc, char **argv) {
// launch kernel
int gX = 1, gY = 1, gZ = 1;
cuStreamSynchronize(stream);
matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, K, N);
matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1);
cuStreamSynchronize(stream);
// read data
@@ -150,7 +150,7 @@ def test_compile_link_matmul():
hints = [":16", ""]
for ha in hints:
for hb in hints:
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, 1, i32{hb}, 1, i32:16, 1, {BM}, {BN}, {BK}'
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}"
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=tmp_dir)

BIN
python/triton/third_party/cuda/bin/ptxas vendored Executable file

Binary file not shown.

View File

@@ -10,7 +10,7 @@
void unload_{kernel_name}(void);
void load_{kernel_name}(void);
// tt-linker: {kernel_name}:{signature}
// tt-linker: {kernel_name}:{full_signature}
CUresult{_placeholder} {kernel_name}(CUstream stream, unsigned int gX,
unsigned int gY, unsigned int gZ,
{signature});

View File

@@ -94,8 +94,15 @@ if __name__ == "__main__":
divisible_by_16 = [i for i, h in hints.items() if h == 16]
equal_to_1 = [i for i, h in hints.items() if h == 1]
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
for i in equal_to_1:
constexprs.update({i: 1})
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps)
arg_names = [kernel.arg_names[i] for i in signature.keys()]
arg_names = []
arg_types = []
for i in signature.keys():
if i not in equal_to_1:
arg_names += [kernel.arg_names[i]]
arg_types += [signature[i]]
# dump C stub code
suffix = kernel_suffix(signature.values(), config)
@@ -107,7 +114,8 @@ if __name__ == "__main__":
"triton_kernel_name": triton_kernel_name,
"bin_size": len(hex_),
"bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, signature.values())]),
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
"full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]),
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]),
"num_args": len(arg_names),
"kernel_docstring": "",

View File

@@ -32,10 +32,10 @@ class HeaderParser:
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+)")
# [name, suffix]
self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
# [(argnum, d|c)]
self.kernel_suffix = re.compile("([0-9]+)([c,d])")
# [(type, name)]
self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?")
# [d|c]
self.arg_suffix = re.compile("[c,d]")
self.kernels = defaultdict(list)
@@ -48,7 +48,7 @@ class HeaderParser:
ker_name, c_sig = m.group(1), m.group(2)
name, sig_hash, suffix = self._match_name(ker_name)
c_types, arg_names = self._match_c_sig(c_sig)
num_specs, sizes = self._match_suffix(suffix)
num_specs, sizes = self._match_suffix(suffix, c_sig)
self._add_kernel(
name,
KernelLinkerMeta(
@@ -79,18 +79,27 @@ class HeaderParser:
raise LinkerError(f"{c_sig} is not a valid argument signature")
def _match_suffix(self, suffix: str):
m = self.kernel_suffix.findall(suffix)
if not len(m):
raise LinkerError(f"{suffix} is not a valid kernel suffix")
sizes = []
num_specs = len(m)
def _match_suffix(self, suffix: str, c_sig: str):
args = c_sig.split(',')
s2i = {"c": 1, "d": 16}
for (argnum, arg_size_ann) in m:
while len(sizes) < int(argnum):
sizes.append(None)
sizes.append(s2i[arg_size_ann])
num_specs = 0
sizes = []
# scan through suffix, first find the index,
# then see if it is followed by d or c
for i in range(len(args)):
pos = suffix.find(str(i))
if pos == -1:
raise LinkerError(f"{suffix} is not a valid kernel suffix")
pos += len(str(i))
if self.arg_suffix.match(suffix, pos):
num_specs += 1
sizes.extend([None] * (i - len(sizes)))
sizes.append(s2i[suffix[pos]])
pos += 1
if i < len(args) - 1:
suffix = suffix[pos:]
else:
sizes.extend([None] * (len(args) - len(sizes)))
return num_specs, sizes
def _add_kernel(self, name: str, ker: KernelLinkerMeta):
@@ -106,13 +115,20 @@ class HeaderParser:
self.kernels[name].append(ker)
def gen_signature(m):
def gen_signature_with_full_args(m):
return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)])
def gen_signature(m):
arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1]
arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1]
sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)])
return sig
def make_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
return f"""
CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(metas[-1])});
CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])});
void load_{name}();
void unload_{name}();
"""
@@ -124,13 +140,14 @@ def make_kernel_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
src += f"CUresult {name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(meta)});\n"
src += "\n"
src += f"CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(metas[-1])}){{"
src += f"CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])}){{"
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
src += f" if ({conds})\n"
src += f" return {name}_{meta.sig_hash}_{meta.suffix}(stream, gX, gY, gZ, {', '.join(meta.arg_names)});\n"
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
src += f" return {name}_{meta.sig_hash}_{meta.suffix}(stream, gX, gY, gZ, {', '.join(arg_names)});\n"
src += "}\n"
for mode in ["load", "unload"]: