mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] using closure to create jit launcher (#2289)
Hi, I'm adding some features to `triton.runtime.jit.JITFunction_make_launcher` and found it is hard to debug it: 1. The inlined Python code is hard to inspect in my editor. 2. My debugger fails to step into these inlined codes. In response, I've introduced some code to solve these issues. My modifications include: ~~1. Refactoring the launcher's inline Python code, ensuring it only relies on the "self" object.~~ ~~2. Add a utility method that generates a temporary file to create a launcher when debugging kernel in main module~~ Using a closure to hold the launcher's body Because this features might be good to others, I have initiated this Pull Request. ~~Tests are yet to be added; if this submission might be accepted, I will add it later.~~ Since this change is a refactor, no new test was added.
This commit is contained in:
@@ -297,29 +297,29 @@ class JITFunction(KernelInterface[T]):
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
|
||||
def _get_arg_specialization_key(self, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg, '')
|
||||
def _get_arg_specialization_key(self, arg_name, arg):
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if arg_annotation == '':
|
||||
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
|
||||
else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \
|
||||
else (False,)'
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
|
||||
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
|
||||
else (False,)
|
||||
elif 'Tensor' in arg_annotation:
|
||||
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
elif 'int' in arg_annotation or 'bool' in arg_annotation:
|
||||
return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)'
|
||||
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
|
||||
else:
|
||||
return '(False,)'
|
||||
return (False,)
|
||||
|
||||
def _get_arg_sig_key(self, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg, '')
|
||||
def _get_arg_sig_key(self, arg_name, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if 'Tensor' in arg_annotation:
|
||||
return f'{arg}.dtype'
|
||||
return arg.dtype
|
||||
elif arg_annotation == 'bool':
|
||||
return "i1"
|
||||
elif arg_annotation == 'float':
|
||||
return 'fp32'
|
||||
else:
|
||||
return f'_key_of({arg})'
|
||||
return self._key_of(arg)
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
@@ -337,124 +337,113 @@ class JITFunction(KernelInterface[T]):
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [f'{arg}' for i, arg in enumerate(
|
||||
regular_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [
|
||||
f'{arg}' for i, arg in enumerate(
|
||||
self.arg_names) if i in self.constexprs]
|
||||
args = ', '.join(regular_args)
|
||||
# cache key for regular argument type
|
||||
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
|
||||
device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']'
|
||||
pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']'
|
||||
# cache key for constexpr argument values
|
||||
constexpr_keys = ', '.join(constexpr_args)
|
||||
# cache key for argument specialization
|
||||
specializations = []
|
||||
for i, arg in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [self._get_arg_specialization_key(arg)]
|
||||
constexpr_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i in self.constexprs]
|
||||
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
|
||||
from ..compiler import (CompiledKernel, compile,
|
||||
get_arch_default_num_stages,
|
||||
get_arch_default_num_warps)
|
||||
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
|
||||
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
|
||||
specializations = []
|
||||
for i, arg_name in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
|
||||
|
||||
spec_key = tuple(specializations)
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
grid = grid(args_proxy)
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != '']
|
||||
device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
|
||||
regular_args_v(args_proxy)])
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ['cuda']:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError('Cannot find backend for ' + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ['cuda']:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ['cuda']:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
bin = self.cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
|
||||
constants.update({i: 1 for i in configs[0].equal_to_1})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
|
||||
# create a wrapper to call launcher_body
|
||||
args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
|
||||
src = f"""
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
grid = grid({{{grid_args}}})
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
|
||||
if device_type is None:
|
||||
device_types = [_device_type for _device_type in {device_types} if _device_type != '']
|
||||
device_type = self._conclude_device_type(device_types, {pinned_memory_flags})
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ['cuda']:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError('Cannot find backend for ' + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ['cuda']:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ['cuda']:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
bin = cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
|
||||
constants.update({{i: 1 for i in configs[0].equal_to_1}})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
|
||||
"""
|
||||
scope = {"version_key": version_key(),
|
||||
"get_cuda_stream": get_cuda_stream,
|
||||
"self": self,
|
||||
"_spec_of": self._spec_of,
|
||||
"_key_of": self._key_of,
|
||||
"_device_of": self._device_of,
|
||||
"_pinned_memory_of": self._pinned_memory_of,
|
||||
"cache": self.cache,
|
||||
"__spec__": __spec__,
|
||||
"get_backend": get_backend,
|
||||
"get_current_device": get_current_device,
|
||||
"set_current_device": set_current_device}
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user