mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[RUNTIME] Get the correct end idx for regular arguments of GPU kernels (#2262)
Previously, if arguments specialized to "1" or declared as "constexpr" were mixed with unspecialized arguments in arbitrary order, we could encounter errors caused by passing incorrect arguments. This happened because the length of the signature does not indicate the maximum index of the regular arguments.

See https://github.com/openai/triton/issues/2229 (cc @shunting314 @amjames).

More specifically, this affects cases like:

```
kernel(
    b: tl.tensor,
    a: tl.constexpr,
    c: tl.int = 1,
    d,
    e: tl.constexpr,
    ...
)
```
This commit is contained in:
@@ -63,8 +63,9 @@ def ty_to_cpp(ty):
|
||||
|
||||
|
||||
def generate_launcher(constants, signature, ids):
|
||||
start_desc = len(signature)
|
||||
signature = generate_cu_signature(constants, signature, ids)
|
||||
# Record the end of regular arguments;
|
||||
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
|
||||
signature, desc_start_idx = generate_cu_signature(constants, signature, ids)
|
||||
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||
|
||||
def _extracted_type(ty):
|
||||
@@ -99,7 +100,7 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
# generate glue code
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
|
||||
@@ -26,12 +26,11 @@ from ..runtime import driver
|
||||
|
||||
def generate_cu_signature(constants, signature, ids):
|
||||
# CUtensorMap*s are always the last arguments
|
||||
num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0
|
||||
if ids["ids_of_tensormaps"] is not None:
|
||||
signature = signature.copy()
|
||||
num_signature = len(signature)
|
||||
for i, _ in enumerate(ids["ids_of_tensormaps"]):
|
||||
signature[num_signature + i] = '*CUtensorMap'
|
||||
return signature
|
||||
signature[num_regular_signatures + i] = '*CUtensorMap'
|
||||
return signature, num_regular_signatures
|
||||
|
||||
|
||||
def dummy_tensormaps_info(n=2):
|
||||
|
||||
Reference in New Issue
Block a user