[FRONTEND] using closure to create jit launcher (#2289)

Hi,

I'm adding some features to
`triton.runtime.jit.JITFunction._make_launcher` and found it hard to
debug:
1. The inlined Python code is hard to inspect in my editor.
2. My debugger fails to step into this inlined code.

In response, I've made some changes to address these issues. My
modifications include:
~~1. Refactoring the launcher's inline Python code, ensuring it only
relies on the "self" object.~~
~~2. Adding a utility method that generates a temporary file to create a
launcher when debugging a kernel in the main module.~~
Using a closure to hold the launcher's body (a minimal sketch of the
idea follows below).
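
To make the idea concrete, here is a minimal sketch of the approach with hypothetical names (`make_launcher`, `launcher_body`, `add` are illustrative, not the actual Triton code): the launch logic lives in a regular Python closure that debuggers can step into, while `exec` only produces a thin wrapper that exposes the kernel's argument names.

```python
# Minimal sketch of the closure approach (hypothetical helper, not Triton's code).
def make_launcher(fn, arg_names):
    def launcher_body(args_proxy, grid):
        # Real, inspectable Python: breakpoints and "step into" work here.
        print(f"launching {fn.__name__} with grid={grid} and args={args_proxy}")

    # The generated source shrinks to a wrapper that just forwards its arguments.
    args = ', '.join(arg_names)
    args_map = ', '.join(f'"{a}": {a}' for a in arg_names)
    src = f"def {fn.__name__}({args}, grid=None):\n" \
          f"    return launcher_body({{{args_map}}}, grid)\n"
    scope = {"launcher_body": launcher_body}
    exec(src, scope)
    return scope[fn.__name__]

def add(x, y):  # stand-in for a @triton.jit kernel
    pass

launcher = make_launcher(add, ["x", "y"])
launcher(1, 2, grid=(4,))  # -> launching add with grid=(4,) and args={'x': 1, 'y': 2}
```

Only the wrapper's signature still has to be generated as text (so the launcher keeps the kernel's argument names); everything else is ordinary, debuggable code.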

Because this feature might be useful to others, I have opened this
pull request.

~~Tests are yet to be added; if this submission is accepted, I
will add them later.~~
Since this change is a refactor, no new tests were added.
Commit: cb83b42ed6 (parent 215b2e77a1)
Author: edimetia3d
Date: 2023-09-23 08:01:54 +08:00 (committed by GitHub)


@@ -297,29 +297,29 @@ class JITFunction(KernelInterface[T]):
         return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
             "key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
 
-    def _get_arg_specialization_key(self, arg) -> str:
-        arg_annotation = self.__annotations__.get(arg, '')
+    def _get_arg_specialization_key(self, arg_name, arg):
+        arg_annotation = self.__annotations__.get(arg_name, '')
         if arg_annotation == '':
-            return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
-                    else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \
-                    else (False,)'
+            return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
+                else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
+                else (False,)
         elif 'Tensor' in arg_annotation:
-            return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
+            return (arg.data_ptr() % JITFunction.divisibility == 0)
         elif 'int' in arg_annotation or 'bool' in arg_annotation:
-            return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)'
+            return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
         else:
-            return '(False,)'
+            return (False,)
 
-    def _get_arg_sig_key(self, arg) -> str:
-        arg_annotation = self.__annotations__.get(arg, '')
+    def _get_arg_sig_key(self, arg_name, arg) -> str:
+        arg_annotation = self.__annotations__.get(arg_name, '')
         if 'Tensor' in arg_annotation:
-            return f'{arg}.dtype'
+            return arg.dtype
         elif arg_annotation == 'bool':
             return "i1"
         elif arg_annotation == 'float':
             return 'fp32'
         else:
-            return f'_key_of({arg})'
+            return self._key_of(arg)
 
     def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
         device_types = [device_type for device_type in device_types if device_type != '']
@@ -337,124 +337,113 @@ class JITFunction(KernelInterface[T]):
         return device_types[0] if len(device_types) > 0 else 'cuda'
 
     def _make_launcher(self):
-        regular_args = [f'{arg}' for i, arg in enumerate(
+        regular_args = [arg for i, arg in enumerate(
             self.arg_names) if i not in self.constexprs]
-        constexpr_args = [
-            f'{arg}' for i, arg in enumerate(
-                self.arg_names) if i in self.constexprs]
-        args = ', '.join(regular_args)
-        # cache key for regular argument type
-        sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
-        device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']'
-        pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']'
-        # cache key for constexpr argument values
-        constexpr_keys = ', '.join(constexpr_args)
-        # cache key for argument specialization
-        specializations = []
-        for i, arg in enumerate(regular_args):
-            if i in self.do_not_specialize:
-                continue
-            specializations += [self._get_arg_specialization_key(arg)]
+        constexpr_args = [arg for i, arg in enumerate(
+            self.arg_names) if i in self.constexprs]
-        spec_keys = ', '.join(specializations)
-        grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
+
+        def regular_args_v(args_proxy):
+            return [args_proxy[arg_name] for arg_name in regular_args]
+
+        def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
+            from ..compiler import (CompiledKernel, compile,
+                                    get_arch_default_num_stages,
+                                    get_arch_default_num_warps)
+            sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
+            constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
+            specializations = []
+            for i, arg_name in enumerate(regular_args):
+                if i in self.do_not_specialize:
+                    continue
+                specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
+            spec_key = tuple(specializations)
+            assert num_ctas > 0
+            assert grid is not None
+            if callable(grid):
+                grid = grid(args_proxy)
+            grid_size = len(grid)
+            grid_0 = grid[0]
+            grid_1 = grid[1] if grid_size > 1 else 1
+            grid_2 = grid[2] if grid_size > 2 else 1
+            if device_type is None:
+                device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
+                device_types = [_device_type for _device_type in device_types if _device_type != '']
+                device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
+                                                                        regular_args_v(args_proxy)])
+            device_backend = None
+            if device_type not in ['cuda']:
+                device_backend = get_backend(device_type)
+                if device_backend is None:
+                    raise ValueError('Cannot find backend for ' + device_type)
+            if device is None:
+                if device_type in ['cuda']:
+                    device = get_current_device()
+                    set_current_device(device)
+                else:
+                    device = device_backend.get_current_device()
+                    device_backend.set_current_device(device)
+            if stream is None and not warmup:
+                if device_type in ['cuda']:
+                    stream = get_cuda_stream(device)
+                else:
+                    stream = device_backend.get_stream()
+            if num_warps is None:
+                num_warps = get_arch_default_num_warps(device_type)
+            if num_stages is None:
+                num_stages = get_arch_default_num_stages(device_type)
+            key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
+            if extern_libs is not None:
+                key = (key, tuple(extern_libs.items()))
+            bin = self.cache[device].get(key, None)
+            if bin is not None:
+                # build dict of constant values
+                args = regular_args_v(args_proxy)
+                # Create tensormaps and append to args
+                args = bin.assemble_tensormap_to_arg(args)
+                if not warmup:
+                    bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
+                return bin
+            # kernel not cached -- compile
+            else:
+                # build dict of constant values
+                args = regular_args_v(args_proxy)
+                all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
+                configs = self._get_config(*all_args),
+                constants = self._make_constants(constexpr_key)
+                constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
+                constants.update({i: 1 for i in configs[0].equal_to_1})
+                # build kernel signature -- doesn't include specialized arguments
+                signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
+                # build stub signature -- includes arguments that are specialized
+                for i, arg in constants.items():
+                    if callable(arg):
+                        raise TypeError(f"Callable constexpr at index {i} is not supported")
+                if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
+                    bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
+                    # Create tensormaps and append to args
+                    args = bin.assemble_tensormap_to_arg(args)
+                    if not warmup:
+                        bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
+                    self.cache[device][key] = bin
+                    return bin
+                return None
+        # create a wrapper to call launcher_body
+        args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
         args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
         args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
         src = f"""
 import triton
 def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
-    from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
-    sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
-    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
-    spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
-    assert num_ctas > 0
-    assert grid is not None
-    if callable(grid):
-        grid = grid({{{grid_args}}})
-    grid_size = len(grid)
-    grid_0 = grid[0]
-    grid_1 = grid[1] if grid_size > 1 else 1
-    grid_2 = grid[2] if grid_size > 2 else 1
-    if device_type is None:
-        device_types = [_device_type for _device_type in {device_types} if _device_type != '']
-        device_type = self._conclude_device_type(device_types, {pinned_memory_flags})
-    device_backend = None
-    if device_type not in ['cuda']:
-        device_backend = get_backend(device_type)
-        if device_backend is None:
-            raise ValueError('Cannot find backend for ' + device_type)
-    if device is None:
-        if device_type in ['cuda']:
-            device = get_current_device()
-            set_current_device(device)
-        else:
-            device = device_backend.get_current_device()
-            device_backend.set_current_device(device)
-    if stream is None and not warmup:
-        if device_type in ['cuda']:
-            stream = get_cuda_stream(device)
-        else:
-            stream = device_backend.get_stream()
-    if num_warps is None:
-        num_warps = get_arch_default_num_warps(device_type)
-    if num_stages is None:
-        num_stages = get_arch_default_num_stages(device_type)
-    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
-    if not extern_libs is None:
-        key = (key, tuple(extern_libs.items()))
-    bin = cache[device].get(key, None)
-    if bin is not None:
-        # build dict of constant values
-        args = [{args}]
-        # Create tensormaps and append to args
-        args = bin.assemble_tensormap_to_arg(args)
-        if not warmup:
-            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
-        return bin
-    # kernel not cached -- compile
-    else:
-        # build dict of constant values
-        args = [{args}]
-        all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
-        configs = self._get_config(*all_args),
-        constants = self._make_constants(constexpr_key)
-        constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
-        constants.update({{i: 1 for i in configs[0].equal_to_1}})
-        # build kernel signature -- doesn't include specialized arguments
-        signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
-        # build stub signature -- includes arguments that are specialized
-        for i, arg in constants.items():
-            if callable(arg):
-                raise TypeError(f"Callable constexpr at index {{i}} is not supported")
-        if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
-            bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
-            # Create tensormaps and append to args
-            args = bin.assemble_tensormap_to_arg(args)
-            if not warmup:
-                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
-            self.cache[device][key] = bin
-            return bin
-        return None
+    return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
 """
scope = {"version_key": version_key(),
"get_cuda_stream": get_cuda_stream,
"self": self,
"_spec_of": self._spec_of,
"_key_of": self._key_of,
"_device_of": self._device_of,
"_pinned_memory_of": self._pinned_memory_of,
"cache": self.cache,
"__spec__": __spec__,
"get_backend": get_backend,
"get_current_device": get_current_device,
"set_current_device": set_current_device}
scope = {"launcher_body": launcher_body}
exec(src, scope)
return scope[self.fn.__name__]
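
For illustration, here is roughly what the generated `src` expands to for a hypothetical kernel `add(x, y, n, BLOCK: tl.constexpr)`: the wrapper only maps positional arguments into a dict and forwards everything, while the caching, specialization, and launch logic stays in the `launcher_body` closure above. A stub `launcher_body` stands in for the real closure so the sketch runs on its own.

```python
# Hypothetical example of the generated wrapper (stub closure so the sketch runs).
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages,
                  enable_warp_specialization, extern_libs, stream, warmup,
                  device, device_type):
    return args_proxy, grid  # the real closure compiles/launches the kernel here

def add(x, y, n, BLOCK, grid=None, num_warps=None, num_ctas=1, num_stages=None,
        enable_warp_specialization=False, extern_libs=None, stream=None,
        warmup=False, device=None, device_type=None):
    return launcher_body({"x": x, "y": y, "n": n, "BLOCK": BLOCK}, grid, num_warps,
                         num_ctas, num_stages, enable_warp_specialization,
                         extern_libs, stream, warmup, device, device_type)

print(add(1, 2, 3, 64, grid=(1,)))  # ({'x': 1, 'y': 2, 'n': 3, 'BLOCK': 64}, (1,))
```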