mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Remove extra calls to _get_config causing runtime overhead (#2094)
This commit is contained in:
@@ -674,7 +674,7 @@ class CompiledKernel:
|
||||
return super().__getattribute__(name)
|
||||
|
||||
# capture args and expand args with cutensormap*
|
||||
def assemble_tensormap_to_arg(self, args, constants):
|
||||
def assemble_tensormap_to_arg(self, args):
|
||||
args_with_tma = list(args)
|
||||
if hasattr(self, 'tensormaps_info'):
|
||||
# tuple for hashable
|
||||
@@ -687,7 +687,7 @@ class CompiledKernel:
|
||||
self._init_handles()
|
||||
|
||||
def runner(*args, stream=None):
|
||||
args_expand = self.assemble_tensormap_to_arg(args, self.constants)
|
||||
args_expand = self.assemble_tensormap_to_arg(args)
|
||||
if stream is None:
|
||||
if self.device_type in ["cuda", "hip"]:
|
||||
stream = get_cuda_stream()
|
||||
|
||||
@@ -407,13 +407,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=4, num_ctas=1, num_s
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
|
||||
constants.update({{i: 1 for i in configs[0].equal_to_1}})
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args, constants)
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
return bin
|
||||
@@ -435,7 +430,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=4, num_ctas=1, num_s
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args, constants)
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
|
||||
Reference in New Issue
Block a user