[RUNTIME] Get the correct end idx for regular arguments of GPU kernels (#2262)

Previously, if arguments specialized to "1" or declared "constexpr" were
mixed with unspecialized arguments in arbitrary order, we could hit
errors caused by passing incorrect arguments. This was because the
length of the signature is not necessarily the maximum index of the
regular arguments plus one, so it cannot be used as the index at which
the trailing descriptor arguments start.

https://github.com/openai/triton/issues/2229

@shunting314 @amjames 

More specifically, this affects kernels with signatures like:

```
kernel(
b: tl.tensor,
a: tl.constexpr,
c: tl.int = 1,
d,
e: tl.constexpr,
...
)
```
This commit is contained in:
Keren Zhou
2023-09-08 02:31:07 -04:00
committed by GitHub
parent 52aa663dcb
commit 10f59d8ce0
2 changed files with 7 additions and 7 deletions

View File

@@ -63,8 +63,9 @@ def ty_to_cpp(ty):
def generate_launcher(constants, signature, ids):
start_desc = len(signature)
signature = generate_cu_signature(constants, signature, ids)
# Record the end of regular arguments;
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
signature, desc_start_idx = generate_cu_signature(constants, signature, ids)
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
def _extracted_type(ty):
@@ -99,7 +100,7 @@ def generate_launcher(constants, signature, ids):
# generate glue code
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
src = f"""
#include \"cuda.h\"
#include <stdbool.h>

View File

@@ -26,12 +26,11 @@ from ..runtime import driver
def generate_cu_signature(constants, signature, ids):
# CUtensorMap*s are always the last arguments
num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0
if ids["ids_of_tensormaps"] is not None:
signature = signature.copy()
num_signature = len(signature)
for i, _ in enumerate(ids["ids_of_tensormaps"]):
signature[num_signature + i] = '*CUtensorMap'
return signature
signature[num_regular_signatures + i] = '*CUtensorMap'
return signature, num_regular_signatures
def dummy_tensormaps_info(n=2):