[FRONTEND] Refactor jit.py. (#2556)

[FRONTEND] Refactor jit.py.

The goal is to simplify the code and make it more flexible before we
change the kernel launch syntax to
`kernel[grid, compiler_flags(...)](...)`.

The main changes here are:

 - Get rid of the eval'ed code in make_launcher.  We can do everything
   using bind().
 - Add KernelParam and KernelArg classes, letting us get rid of the
   parallel arrays/dicts indexed by parameter index.
 - Get rid of duplicated kernel launch code in the cache-hit/cache-miss
   branches.
This commit is contained in:
Justin Lebar
2023-10-30 13:14:51 -07:00
committed by GitHub
parent f88b01f558
commit 12f906287f

View File

@@ -7,11 +7,11 @@ import inspect
import os
import textwrap
from collections import defaultdict, namedtuple
from functools import cached_property
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, get_cuda_version_key
from ..language.core import dtype
from .interpreter import InterpretedFunction
@@ -112,6 +112,85 @@ def _normalize_ty(ty) -> str:
return repr(ty)
class KernelParam:
    """Represents a parameter to a @jit'ed function.

    A parameter is just the name plus metadata; a parameter plus a value is a
    KernelArg.
    """

    def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
        # Positional index of this parameter in the kernel signature.
        self.num = num
        self._param = param
        # If True, this parameter is excluded from specialization keys.
        self.do_not_specialize = do_not_specialize

    @cached_property
    def name(self):
        return self._param.name

    @cached_property
    def annotation(self):
        # inspect.Parameter.empty is a sentinel: compare with `is`, not `==`,
        # so annotation objects that overload __eq__ cannot misbehave.
        if not self._param.annotation or self._param.annotation is inspect.Parameter.empty:
            return ""
        return _normalize_ty(self._param.annotation)

    @cached_property
    def is_constexpr(self):
        return "constexpr" in self.annotation

    @property
    def default(self):
        return self._param.default

    @property
    def has_default(self):
        # Sentinel comparison must use `is not`: a default value whose type
        # overloads __eq__ (e.g. a numpy array) would make `!=` return a
        # non-bool (or raise) instead of True/False.
        return self._param.default is not inspect.Parameter.empty
class KernelArg:
    """Represents an argument to a @jit'ed function.

    An argument is a parameter plus a value.
    """

    def __init__(self, value, param):
        self.value = value
        self.param = param

    @property
    def name(self):
        return self.param.name

    def signature_key(self):
        """Key for this argument's slot in the kernel-signature cache."""
        ann = self.param.annotation
        if "Tensor" in ann:
            return self.value.dtype
        if ann == "bool":
            return "i1"
        if ann == "float":
            return "fp32"
        return JITFunction._key_of(self.value)

    def specialization_key(self):
        """Divisibility/equality properties used to specialize the kernel."""
        assert not self.param.do_not_specialize
        value = self.value
        try:
            return (value.data_ptr() % JITFunction.divisibility == 0,)
        except AttributeError:
            pass
        if isinstance(value, int):
            # bool is a subclass of int, so bools take this path too.
            return (
                value % JITFunction.divisibility == 0,
                value % JITFunction.divisibility_8 == 0,
                value == 1,
            )
        return (False,)
class KernelInterface(Generic[T]):
run: T
@@ -156,19 +235,17 @@ class JITFunction(KernelInterface[T]):
@staticmethod
def _device_of(arg):
if hasattr(arg, "device"):
if hasattr(arg.device, "type"):
return arg.device.type
return ""
try:
return arg.device.type
except AttributeError:
return ""
@staticmethod
def _pinned_memory_of(arg):
if hasattr(arg, "is_pinned"):
if isinstance(arg.is_pinned, Callable):
return arg.is_pinned()
return False
try:
return arg.is_pinned()
except (AttributeError, TypeError):
return False
@staticmethod
def _spec_of(arg):
@@ -178,6 +255,7 @@ class JITFunction(KernelInterface[T]):
return (arg % 16 == 0, arg == 1)
return (arg is None,)
# TODO(jlebar): Fold this into the KernelArg class.
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
@@ -196,17 +274,21 @@ class JITFunction(KernelInterface[T]):
return False
divisible_by_16 = {
i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_16(arg) and not param.do_not_specialize
}
divisible_by_8 = {
param.num for param, arg in zip(self.params, args) if is_divisible_by_8(arg) and not param.do_not_specialize
}
divisible_by_8 = {i for i, arg in enumerate(args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
equal_to_1 = {
i
for i, arg in enumerate(args)
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and i not in self.do_not_specialize
param.num
for param, arg in zip(self.params, args)
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
}
# folded equal_to_1 and None
# TODO: method to collect all folded args
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
ids_of_folded_args = equal_to_1 | none_args
return namedtuple(
"instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]
@@ -216,7 +298,7 @@ class JITFunction(KernelInterface[T]):
@staticmethod
def _type_of(key):
# None are nullptr -- implicitly converted to *i8
# `None` is nullptr. Implicitly convert to *i8.
if key is None:
return "*i8"
dtype_str = str(key).split(".")[-1]
@@ -266,9 +348,10 @@ class JITFunction(KernelInterface[T]):
):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ", ".join([f"{name}: {ty}" for name, ty in zip(self.arg_names, key[1])])
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
key = str(key)
@@ -300,34 +383,6 @@ class JITFunction(KernelInterface[T]):
already_compiled=False,
)
def _get_arg_specialization_key(self, arg_name, arg):
arg_annotation = self.__annotations__.get(arg_name, "")
if arg_annotation == "":
return (
(arg.data_ptr() % JITFunction.divisibility == 0)
if hasattr(arg, "data_ptr")
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
if isinstance(arg, int)
else (False,)
)
elif "Tensor" in arg_annotation:
return arg.data_ptr() % JITFunction.divisibility == 0
elif "int" in arg_annotation or "bool" in arg_annotation:
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
else:
return (False,)
def _get_arg_sig_key(self, arg_name, arg) -> str:
arg_annotation = self.__annotations__.get(arg_name, "")
if "Tensor" in arg_annotation:
return arg.dtype
elif arg_annotation == "bool":
return "i1"
elif arg_annotation == "float":
return "fp32"
else:
return self._key_of(arg)
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
device_types = [device_type for device_type in device_types if device_type != ""]
# Return cuda if one of the input tensors is cuda
@@ -344,219 +399,185 @@ class JITFunction(KernelInterface[T]):
return device_types[0] if len(device_types) > 0 else "cuda"
def _make_launcher(self):
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [arg for i, arg in enumerate(self.arg_names) if i in self.constexprs]
def run(self, *args, **kwargs):
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
def regular_args_v(args_proxy):
return [args_proxy[arg_name] for arg_name in regular_args]
# Get a compiler-flags arg like `num_warps` and remove it from kwargs.
def get_special_arg(name: str, default=None):
if name not in kwargs:
return default
ret = kwargs[name]
del kwargs[name]
return ret
def launcher_body(
args_proxy,
grid,
grid = get_special_arg("grid")
num_warps = get_special_arg("num_warps")
num_ctas = get_special_arg("num_ctas", 1)
num_stages = get_special_arg("num_stages")
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
extern_libs = get_special_arg("extern_libs")
stream = get_special_arg("stream")
warmup = get_special_arg("warmup", False)
device = get_special_arg("device")
device_type = get_special_arg("device_type")
# Bind the remaining arguments to `fn`.
bound_args = self.signature.bind(*args, **kwargs)
bound_args.apply_defaults()
assert len(bound_args.arguments) == len(self.params)
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
assert num_ctas > 0
assert grid is not None
if callable(grid):
# Arguments are passed as a dict to `grid`, by contract.
# TODO(jlebar): In the new launch API, pass the compiler flags as a
# second parameter to `grid`.
grid = grid(dict(bound_args.arguments))
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device_type is None:
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
device_types = [_device_type for _device_type in device_types if _device_type != ""]
device_type = self._conclude_device_type(
device_types, [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values]
)
device_backend = None
if device_type not in ["cuda"]:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError("Cannot find backend for " + device_type)
if device is None:
if device_type in ["cuda"]:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ["cuda"]:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
if num_warps is None:
num_warps = get_arch_default_num_warps(device_type)
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
if device_type in ["cuda"]:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
key = (
version_key,
sig_key,
constexpr_key,
spec_key,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
stream,
warmup,
device,
device_type,
):
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
self.debug,
)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
specializations = []
for i, arg_name in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
# Kernel is not cached; we have to compile.
if key not in self.cache[device]:
configs = (self._get_config(*[arg.value for arg in args]),)
constants = {
arg.param.num: arg.value
for arg in args
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
}
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
spec_key = tuple(specializations)
assert num_ctas > 0
assert grid is not None
if callable(grid):
grid = grid(args_proxy)
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device_type is None:
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
device_types = [_device_type for _device_type in device_types if _device_type != ""]
device_type = self._conclude_device_type(
device_types, [self._pinned_memory_of(arg) for arg in regular_args_v(args_proxy)]
)
# Build kernel signature -- doesn't include constexpr arguments.
signature = {
arg.param.num: self._type_of(self._key_of(arg.value)) for arg in args if not arg.param.is_constexpr
}
device_backend = None
if device_type not in ["cuda"]:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError("Cannot find backend for " + device_type)
if device is None:
if device_type in ["cuda"]:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ["cuda"]:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
if num_warps is None:
num_warps = get_arch_default_num_warps(device_type)
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
if device_type in ["cuda"]:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
key = (
version_key,
sig_key,
constexpr_key,
spec_key,
if self._call_hook(
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
self.debug,
)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
bin = self.cache[device].get(key, None)
if bin is not None:
# build dict of constant values
args = regular_args_v(args_proxy)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(
grid_0,
grid_1,
grid_2,
bin.num_warps,
bin.num_ctas,
bin.clusterDims[0],
bin.clusterDims[1],
bin.clusterDims[2],
bin.shared,
stream,
bin.cu_function,
CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook,
bin,
*args,
)
return bin
# kernel not cached -- compile
else:
# build dict of constant values
args = regular_args_v(args_proxy)
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
configs = (self._get_config(*all_args),)
constants = self._make_constants(constexpr_key)
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
constants.update({i: 1 for i in configs[0].equal_to_1})
# build kernel signature -- doesn't include specialized arguments
signature = {
i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs
}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
if not self._call_hook(
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
bin = compile(
self,
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
debug=self.debug,
device_type=device_type,
)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(
grid_0,
grid_1,
grid_2,
bin.num_warps,
bin.num_ctas,
bin.clusterDims[0],
bin.clusterDims[1],
bin.clusterDims[2],
bin.shared,
stream,
bin.cu_function,
CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook,
bin,
*args,
)
self.cache[device][key] = bin
return bin
extern_libs,
configs,
):
return None
# create a wrapper to call launcher_body
args_map = ",".join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ", ".join(
name
if dflt == inspect._empty
else f"{name} = triton.language.dtype('{dflt}')"
if dtype.is_dtype(f"{dflt}")
else f"{name} = {dflt}"
for name, dflt in zip(self.arg_names, self.arg_defaults)
)
args_signature = args_signature + ", " if len(args_signature) > 0 else ""
src = f"""
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
"""
scope = {"launcher_body": launcher_body}
exec(src, scope)
return scope[self.fn.__name__]
self.cache[device][key] = compile(
self,
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
debug=self.debug,
device_type=device_type,
)
bin = self.cache[device][key]
if not warmup:
bin.c_wrapper(
grid_0,
grid_1,
grid_2,
bin.num_warps,
bin.num_ctas,
bin.clusterDims[0],
bin.clusterDims[1],
bin.clusterDims[2],
bin.shared,
stream,
bin.cu_function,
CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook,
bin,
*bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
)
return bin
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
do_not_specialize = do_not_specialize if do_not_specialize else []
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
self.signature = inspect.signature(fn)
self.params = []
for i, param in enumerate(self.signature.parameters.values()):
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
self.params.append(KernelParam(i, param, dns))
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def") :]
@@ -565,24 +586,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if "constexpr" in ty]
# specialization hints
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {
regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize
}
# tma info
self.tensormaps_info = TMAInfos()
# launcher
self.run = self._make_launcher()
# TODO(jlebar): Remove uses of these fields outside this file, then
# remove the fields here.
self.arg_names = [p.name for p in self.params]
self.constexprs = [p.num for p in self.params if p.is_constexpr]
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
@@ -615,10 +630,6 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == "kernel_decorators":
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized