mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Refactor jit.py. (#2556)
[FRONTEND] Refactor jit.py. The goal is to simplify the code and make it more flexible before we change the kernel launch syntax to `kernel[grid, compiler_flags(...)](...)`. The main changes here are: - Get rid of the eval'ed code in make_launcher. We can do everything using bind(). - Add KernelParam and KernelArg classes, letting us get rid of the parallel arrays/dicts indexed by parameter index. - Get rid of duplicated kernel launch code in the cache-hit/cache-miss branches.
This commit is contained in:
@@ -7,11 +7,11 @@ import inspect
|
||||
import os
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from functools import cached_property
|
||||
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, get_cuda_version_key
|
||||
from ..language.core import dtype
|
||||
from .interpreter import InterpretedFunction
|
||||
|
||||
|
||||
@@ -112,6 +112,85 @@ def _normalize_ty(ty) -> str:
|
||||
return repr(ty)
|
||||
|
||||
|
||||
class KernelParam:
    """Represents a parameter to a @jit'ed function.

    A parameter is just the name plus metadata; a parameter plus a value is a
    KernelArg.
    """

    def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
        # Positional index of this parameter in the kernel signature.
        self.num = num
        self._param = param
        # If True, this parameter is excluded from specialization keys.
        self.do_not_specialize = do_not_specialize

    @cached_property
    def name(self) -> str:
        return self._param.name

    @cached_property
    def annotation(self) -> str:
        # `inspect.Parameter.empty` is a sentinel singleton; per PEP 8 it
        # should be compared with `is`, not `==` (which could be fooled by a
        # custom __eq__ on an annotation object).
        if not self._param.annotation or self._param.annotation is inspect.Parameter.empty:
            return ""
        return _normalize_ty(self._param.annotation)

    @cached_property
    def is_constexpr(self) -> bool:
        # constexpr-ness is encoded in the (normalized) annotation string.
        return "constexpr" in self.annotation

    @property
    def default(self):
        return self._param.default

    @property
    def has_default(self) -> bool:
        # Sentinel comparison: identity, not equality (see `annotation`).
        return self._param.default is not inspect.Parameter.empty
|
||||
|
||||
|
||||
class KernelArg:
    """Represents an argument to a @jit'ed function.

    An argument is a parameter plus a value.
    """

    def __init__(self, value, param):
        self.value = value
        self.param = param

    @property
    def name(self):
        return self.param.name

    def signature_key(self):
        """Key describing this argument's type for the kernel-cache signature."""
        ann = self.param.annotation
        if "Tensor" in ann:
            return self.value.dtype
        scalar_keys = {"bool": "i1", "float": "fp32"}
        if ann in scalar_keys:
            return scalar_keys[ann]
        return JITFunction._key_of(self.value)

    def specialization_key(self):
        """Tuple of divisibility/value properties used to specialize the kernel."""
        assert not self.param.do_not_specialize

        # Pointer-like values (anything exposing data_ptr) specialize on
        # alignment of the address; EAFP keeps the common path cheap.
        try:
            return (self.value.data_ptr() % JITFunction.divisibility == 0,)
        except AttributeError:
            pass

        # bool is a subclass of int, so bools are handled here as well.
        if isinstance(self.value, int):
            return (
                self.value % JITFunction.divisibility == 0,
                self.value % JITFunction.divisibility_8 == 0,
                self.value == 1,
            )

        return (False,)
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
@@ -156,19 +235,17 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
@staticmethod
|
||||
def _device_of(arg):
|
||||
if hasattr(arg, "device"):
|
||||
if hasattr(arg.device, "type"):
|
||||
return arg.device.type
|
||||
|
||||
return ""
|
||||
try:
|
||||
return arg.device.type
|
||||
except AttributeError:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _pinned_memory_of(arg):
|
||||
if hasattr(arg, "is_pinned"):
|
||||
if isinstance(arg.is_pinned, Callable):
|
||||
return arg.is_pinned()
|
||||
|
||||
return False
|
||||
try:
|
||||
return arg.is_pinned()
|
||||
except (AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
@@ -178,6 +255,7 @@ class JITFunction(KernelInterface[T]):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None,)
|
||||
|
||||
# TODO(jlebar): Fold this into the KernelArg class.
|
||||
def _get_config(self, *args):
|
||||
def is_divisible_by_16(x):
|
||||
if hasattr(x, "data_ptr"):
|
||||
@@ -196,17 +274,21 @@ class JITFunction(KernelInterface[T]):
|
||||
return False
|
||||
|
||||
divisible_by_16 = {
|
||||
i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_16(arg) and not param.do_not_specialize
|
||||
}
|
||||
divisible_by_8 = {
|
||||
param.num for param, arg in zip(self.params, args) if is_divisible_by_8(arg) and not param.do_not_specialize
|
||||
}
|
||||
divisible_by_8 = {i for i, arg in enumerate(args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {
|
||||
i
|
||||
for i, arg in enumerate(args)
|
||||
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and i not in self.do_not_specialize
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
|
||||
}
|
||||
# folded equal_to_1 and None
|
||||
# TODO: method to collect all folded args
|
||||
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
|
||||
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
|
||||
ids_of_folded_args = equal_to_1 | none_args
|
||||
return namedtuple(
|
||||
"instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]
|
||||
@@ -216,7 +298,7 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
# None are nullptr -- implicitly converted to *i8
|
||||
# `None` is nullptr. Implicitly convert to *i8.
|
||||
if key is None:
|
||||
return "*i8"
|
||||
dtype_str = str(key).split(".")[-1]
|
||||
@@ -266,9 +348,10 @@ class JITFunction(KernelInterface[T]):
|
||||
):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ", ".join([f"{name}: {ty}" for name, ty in zip(self.arg_names, key[1])])
|
||||
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
@@ -300,34 +383,6 @@ class JITFunction(KernelInterface[T]):
|
||||
already_compiled=False,
|
||||
)
|
||||
|
||||
def _get_arg_specialization_key(self, arg_name, arg):
|
||||
arg_annotation = self.__annotations__.get(arg_name, "")
|
||||
if arg_annotation == "":
|
||||
return (
|
||||
(arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
if hasattr(arg, "data_ptr")
|
||||
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
|
||||
if isinstance(arg, int)
|
||||
else (False,)
|
||||
)
|
||||
elif "Tensor" in arg_annotation:
|
||||
return arg.data_ptr() % JITFunction.divisibility == 0
|
||||
elif "int" in arg_annotation or "bool" in arg_annotation:
|
||||
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
|
||||
else:
|
||||
return (False,)
|
||||
|
||||
def _get_arg_sig_key(self, arg_name, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg_name, "")
|
||||
if "Tensor" in arg_annotation:
|
||||
return arg.dtype
|
||||
elif arg_annotation == "bool":
|
||||
return "i1"
|
||||
elif arg_annotation == "float":
|
||||
return "fp32"
|
||||
else:
|
||||
return self._key_of(arg)
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != ""]
|
||||
# Return cuda if one of the input tensors is cuda
|
||||
@@ -344,219 +399,185 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else "cuda"
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [arg for i, arg in enumerate(self.arg_names) if i in self.constexprs]
|
||||
def run(self, *args, **kwargs):
|
||||
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
|
||||
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
# Get a compiler-flags arg like `num_warps` and remove it from kwargs.
|
||||
def get_special_arg(name: str, default=None):
|
||||
if name not in kwargs:
|
||||
return default
|
||||
ret = kwargs[name]
|
||||
del kwargs[name]
|
||||
return ret
|
||||
|
||||
def launcher_body(
|
||||
args_proxy,
|
||||
grid,
|
||||
grid = get_special_arg("grid")
|
||||
num_warps = get_special_arg("num_warps")
|
||||
num_ctas = get_special_arg("num_ctas", 1)
|
||||
num_stages = get_special_arg("num_stages")
|
||||
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
|
||||
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
|
||||
extern_libs = get_special_arg("extern_libs")
|
||||
stream = get_special_arg("stream")
|
||||
warmup = get_special_arg("warmup", False)
|
||||
device = get_special_arg("device")
|
||||
device_type = get_special_arg("device_type")
|
||||
|
||||
# Bind the remaining arguments to `fn`.
|
||||
bound_args = self.signature.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
assert len(bound_args.arguments) == len(self.params)
|
||||
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
|
||||
|
||||
non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]
|
||||
|
||||
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
|
||||
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
|
||||
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
|
||||
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
# Arguments are passed as a dict to `grid`, by contract.
|
||||
# TODO(jlebar): In the new launch API, pass the compiler flags as a
|
||||
# second parameter to `grid`.
|
||||
grid = grid(dict(bound_args.arguments))
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != ""]
|
||||
device_type = self._conclude_device_type(
|
||||
device_types, [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values]
|
||||
)
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ["cuda"]:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError("Cannot find backend for " + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ["cuda"]:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
if device_type in ["cuda"]:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
key = (
|
||||
version_key,
|
||||
sig_key,
|
||||
constexpr_key,
|
||||
spec_key,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
stream,
|
||||
warmup,
|
||||
device,
|
||||
device_type,
|
||||
):
|
||||
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
|
||||
self.debug,
|
||||
)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
|
||||
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
|
||||
specializations = []
|
||||
for i, arg_name in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
|
||||
# Kernel is not cached; we have to compile.
|
||||
if key not in self.cache[device]:
|
||||
configs = (self._get_config(*[arg.value for arg in args]),)
|
||||
constants = {
|
||||
arg.param.num: arg.value
|
||||
for arg in args
|
||||
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
|
||||
}
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
|
||||
spec_key = tuple(specializations)
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
grid = grid(args_proxy)
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != ""]
|
||||
device_type = self._conclude_device_type(
|
||||
device_types, [self._pinned_memory_of(arg) for arg in regular_args_v(args_proxy)]
|
||||
)
|
||||
# Build kernel signature -- doesn't include constexpr arguments.
|
||||
signature = {
|
||||
arg.param.num: self._type_of(self._key_of(arg.value)) for arg in args if not arg.param.is_constexpr
|
||||
}
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ["cuda"]:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError("Cannot find backend for " + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ["cuda"]:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
if device_type in ["cuda"]:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
key = (
|
||||
version_key,
|
||||
sig_key,
|
||||
constexpr_key,
|
||||
spec_key,
|
||||
if self._call_hook(
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
self.debug,
|
||||
)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
bin = self.cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*args,
|
||||
)
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
|
||||
configs = (self._get_config(*all_args),)
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
|
||||
constants.update({i: 1 for i in configs[0].equal_to_1})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
signature = {
|
||||
i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs
|
||||
}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
if not self._call_hook(
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
bin = compile(
|
||||
self,
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
debug=self.debug,
|
||||
device_type=device_type,
|
||||
)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*args,
|
||||
)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
return None
|
||||
|
||||
# create a wrapper to call launcher_body
|
||||
args_map = ",".join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ", ".join(
|
||||
name
|
||||
if dflt == inspect._empty
|
||||
else f"{name} = triton.language.dtype('{dflt}')"
|
||||
if dtype.is_dtype(f"{dflt}")
|
||||
else f"{name} = {dflt}"
|
||||
for name, dflt in zip(self.arg_names, self.arg_defaults)
|
||||
)
|
||||
args_signature = args_signature + ", " if len(args_signature) > 0 else ""
|
||||
src = f"""
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
|
||||
"""
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
self.cache[device][key] = compile(
|
||||
self,
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
debug=self.debug,
|
||||
device_type=device_type,
|
||||
)
|
||||
|
||||
bin = self.cache[device][key]
|
||||
if not warmup:
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
|
||||
)
|
||||
return bin
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
|
||||
do_not_specialize = do_not_specialize if do_not_specialize else []
|
||||
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.version = version
|
||||
# function signature information
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
self.signature = inspect.signature(fn)
|
||||
|
||||
self.params = []
|
||||
for i, param in enumerate(self.signature.parameters.values()):
|
||||
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
|
||||
self.params.append(KernelParam(i, param, dns))
|
||||
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def") :]
|
||||
@@ -565,24 +586,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
|
||||
self.noinline = noinline
|
||||
# annotations
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if "constexpr" in ty]
|
||||
# specialization hints
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {
|
||||
regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize
|
||||
}
|
||||
|
||||
# tma info
|
||||
self.tensormaps_info = TMAInfos()
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
|
||||
# TODO(jlebar): Remove uses of these fields outside this file, then
|
||||
# remove the fields here.
|
||||
self.arg_names = [p.name for p in self.params]
|
||||
self.constexprs = [p.num for p in self.params if p.is_constexpr]
|
||||
|
||||
# re-use docs of wrapped function
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
@@ -615,10 +630,6 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
if name == "kernel_decorators":
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
|
||||
Reference in New Issue
Block a user