[FRONTEND] Remove extra calls to _get_config causing runtime overhead (#2094)

This commit is contained in:
Thomas
2023-08-13 06:51:26 -07:00
committed by GitHub
parent a01c116f76
commit 98372f46d3
2 changed files with 4 additions and 9 deletions

View File

@@ -674,7 +674,7 @@ class CompiledKernel:
return super().__getattribute__(name)
# capture args and expand args with cutensormap*
- def assemble_tensormap_to_arg(self, args, constants):
+ def assemble_tensormap_to_arg(self, args):
args_with_tma = list(args)
if hasattr(self, 'tensormaps_info'):
# tuple for hashable
@@ -687,7 +687,7 @@ class CompiledKernel:
self._init_handles()
def runner(*args, stream=None):
- args_expand = self.assemble_tensormap_to_arg(args, self.constants)
+ args_expand = self.assemble_tensormap_to_arg(args)
if stream is None:
if self.device_type in ["cuda", "hip"]:
stream = get_cuda_stream()

View File

@@ -407,13 +407,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=4, num_ctas=1, num_s
if bin is not None:
# build dict of constant values
args = [{args}]
- all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
- configs = self._get_config(*all_args),
- constants = self._make_constants(constexpr_key)
- constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
- constants.update({{i: 1 for i in configs[0].equal_to_1}})
# Create tensormaps and append to args
- args = bin.assemble_tensormap_to_arg(args, constants)
+ args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
return bin
@@ -435,7 +430,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=4, num_ctas=1, num_s
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
# Create tensormaps and append to args
- args = bin.assemble_tensormap_to_arg(args, constants)
+ args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin