Apply ruff pre-commit to python/triton/runtime. (#2558)

We're in the process of incrementally converting from autopep8 + flake8
+ isort to ruff, on a directory-by-directory basis.

The motivation to switch away from autopep8 is that I can't get it to
wrap long lines, even with -aaa.  This seems to be a known problem,
https://github.com/hhatto/autopep8/issues/497.

See more details about alternatives tried in
https://github.com/openai/triton/pull/2557.
This commit is contained in:
Justin Lebar
2023-10-30 11:06:44 -07:00
committed by GitHub
parent f7be5f8fa5
commit f88b01f558
10 changed files with 393 additions and 190 deletions

View File

@@ -1,5 +1,4 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .autotuner import Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics
from .driver import driver
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret

View File

@@ -10,10 +10,10 @@ from .jit import KernelInterface
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = (
f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. "
+ "Reducing block sizes or `num_stages` may help."
)
self.required = required
self.limit = limit
self.name = name
@@ -26,12 +26,12 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
'''
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
"""
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
else:
@@ -46,13 +46,14 @@ class Autotuner(KernelInterface):
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
@@ -78,15 +79,20 @@ class Autotuner(KernelInterface):
if config.pre_hook:
config.pre_hook(full_nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current,
)
try:
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
return [float("inf"), float("inf"), float("inf")]
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
@@ -105,8 +111,7 @@ class Autotuner(KernelInterface):
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
@@ -119,9 +124,15 @@ class Autotuner(KernelInterface):
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
if config.pre_hook is not None:
config.pre_hook(full_nargs)
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
ret = self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
**kwargs,
**config.kwargs,
)
self.nargs = None
return ret
@@ -135,17 +146,19 @@ class Autotuner(KernelInterface):
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent)
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent,
)
for config in pruned_configs
}
pruned_configs = sorted(
est_timing.keys(),
key=lambda x: est_timing[x])[
:top_k]
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
@@ -195,14 +208,13 @@ class Config:
def __str__(self):
res = []
for k, v in self.kwargs.items():
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
res.append(f'num_ctas: {self.num_ctas}')
res.append(f'num_stages: {self.num_stages}')
res.append(
f'enable_warp_specialization: {self.enable_warp_specialization}')
res.append(f'enable_persistent: {self.enable_persistent}')
return ', '.join(res)
res.append(f"{k}: {v}")
res.append(f"num_warps: {self.num_warps}")
res.append(f"num_ctas: {self.num_ctas}")
res.append(f"num_stages: {self.num_stages}")
res.append(f"enable_warp_specialization: {self.enable_warp_specialization}")
res.append(f"enable_persistent: {self.enable_persistent}")
return ", ".join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
@@ -241,6 +253,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
:type rep: int
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
@@ -248,7 +261,6 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
class Heuristics(KernelInterface):
def __init__(self, fn, arg_names, values) -> None:
self.fn = fn
self.values = values
@@ -276,6 +288,7 @@ def heuristics(values):
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
return Heuristics(fn, fn.arg_names, values)

View File

@@ -47,17 +47,17 @@ class FileCacheManager(CacheManager):
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if (dump):
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif (override):
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
@@ -141,6 +141,7 @@ def get_cache_manager(key) -> CacheManager:
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib
module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)

View File

@@ -9,7 +9,6 @@ from .cache import get_cache_manager
class DriverBase(metaclass=abc.ABCMeta):
CUDA = 0
HIP = 1
@@ -19,15 +18,16 @@ class DriverBase(metaclass=abc.ABCMeta):
def __init__(self) -> None:
pass
# -----------------------------
# CUDA
# -----------------------------
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
@@ -47,6 +47,7 @@ class CudaUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -64,9 +65,8 @@ class CudaUtils(object):
class CudaDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaDriver, cls).__new__(cls)
return cls.instance
@@ -74,6 +74,7 @@ class CudaDriver(DriverBase):
self.utils = CudaUtils()
self.backend = self.CUDA
# -----------------------------
# HIP
# -----------------------------
@@ -81,7 +82,7 @@ class CudaDriver(DriverBase):
class HIPUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPUtils, cls).__new__(cls)
return cls.instance
@@ -101,6 +102,7 @@ class HIPUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -109,9 +111,8 @@ class HIPUtils(object):
class HIPDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPDriver, cls).__new__(cls)
return cls.instance
@@ -121,9 +122,8 @@ class HIPDriver(DriverBase):
class UnsupportedDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
return cls.instance
@@ -131,6 +131,7 @@ class UnsupportedDriver(DriverBase):
self.utils = None
self.backend = None
# -----------------------------
# Driver
# -----------------------------
@@ -150,7 +151,7 @@ class LazyProxy:
return getattr(self._obj, name)
def __setattr__(self, name, value):
if name in ['_init_fn', '_obj']:
if name in ["_init_fn", "_obj"]:
super().__setattr__(name, value)
else:
self._initialize_obj()
@@ -172,6 +173,7 @@ class LazyProxy:
def initialize_driver():
import torch
if torch.version.hip is not None:
return HIPDriver()
elif torch.cuda.is_available():

View File

@@ -1,10 +1,7 @@
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
self.message += ". Reducing block sizes or `num_stages` may help."
self.required = required
self.limit = limit
self.name = name

View File

@@ -37,7 +37,6 @@ def str_to_ty(name):
class TensorHandle:
def __init__(self, data, dtype):
self.data = data
self.dtype = dtype
@@ -47,7 +46,6 @@ class TensorHandle:
class BlockPointerHandle:
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
self.base = base
self.shape = shape
@@ -78,12 +76,13 @@ def wrap_ret(compute_ret_ty):
def wrapped(*args, **kwargs):
ret = fn(*args, **kwargs)
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
return wrapped
return wrapper
class Builder:
def __init__(self) -> None:
self.arch = None
# pass
@@ -249,11 +248,13 @@ class Builder:
# ternary functions
def ternary_op(self, lhs, rhs, other, op):
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
# unary functions
def unary_op(self, arg, op):
return TensorHandle(op(arg.data), arg.dtype)
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
@@ -279,7 +280,9 @@ class Builder:
dtype_tt = ptr.dtype.element_ty
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
def create_tensor_pointer_load(
self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile
):
ptrs, masks = ptr.materialize_pointers(boundary_check)
assert padding_option is None
other = None
@@ -297,6 +300,7 @@ class Builder:
def create_int_to_ptr(self, val, dst_ty):
return TensorHandle(val.data.astype(np.uint64), dst_ty)
# def create_cat(self, lhs, rhs):
# pass
@@ -360,7 +364,9 @@ class Builder:
def patch_attr(obj, name, member, builder):
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
new_member = lambda *args, member=member, **kwargs: (
member(*args, **{k: v for k, v in kwargs.items() if k != "_builder"}, _builder=builder)
)
setattr(obj, name, new_member)
@@ -384,8 +390,8 @@ def _patch_lang_core(lang, builder):
def _new_reduce(input, axis, combine_fn):
fn = combine_fn.fn.__name__
mapping = {
'maximum': np.max,
'_sum_combine': np.sum,
"maximum": np.max,
"_sum_combine": np.sum,
}
ret = mapping[fn](input.handle.data, axis=axis)
ret_type = tl.block_type(input.dtype, ret.shape)
@@ -397,12 +403,12 @@ def _patch_lang_core(lang, builder):
def _patch_lang_math(lang, builder):
math = lang.math
mapping = {
'abs': 'abs',
'acos': 'arccos',
'asin': 'arcsin',
'exp2': 'exp2',
'log2': 'log2',
'max': 'maximum',
"abs": "abs",
"acos": "arccos",
"asin": "arcsin",
"exp2": "exp2",
"log2": "log2",
"max": "maximum",
}
def make_numpy(name):
@@ -414,15 +420,19 @@ def _patch_lang_math(lang, builder):
ret = getattr(np, mapping[name])(*args, **kwargs)
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
return ret
return impl
def make_fallback(name):
def fallback(*args, **kwargs):
raise NotImplementedError(f"""
raise NotImplementedError(
f"""
{name} not supported in interpreter mode: no known numpy implementation.
If you think that {name} in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
""")
"""
)
return fallback
for name, member in inspect.getmembers(math):
@@ -438,7 +448,7 @@ def _implicit_cvt(arg):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
return tl.tensor(handle, ty)
if hasattr(arg, 'data_ptr'):
if hasattr(arg, "data_ptr"):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
return tl.tensor(handle, ty)
@@ -453,28 +463,28 @@ def _unwrap(tensor):
builder = Builder()
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
class GridExecutor:
def __init__(self, fn, arg_names, grid):
from .jit import _normalize_ty # TODO: modularize
self.fn = fn
self.arg_names = arg_names
self.grid = grid
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
_patch_lang_math(lang[0], builder)
def __call__(self, *args_dev, **kwargs):
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
# removes reserved keywords from kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
# remaps core language functions to interpreted ones
@@ -495,26 +505,26 @@ class GridExecutor:
self.fn(**args)
# copy arguments back to propagate side-effects
for arg_dev, arg_hst in zip(args_dev, args_hst):
if hasattr(arg_dev, 'data_ptr'):
if hasattr(arg_dev, "data_ptr"):
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
class InterpretedFunction:
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
def __init__(self, fn) -> None:
self.fn = fn
def run(*args, **kwargs):
grid = kwargs['grid']
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
grid = kwargs["grid"]
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
self.run = run
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

View File

@@ -7,8 +7,7 @@ import inspect
import os
import textwrap
from collections import defaultdict, namedtuple
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
overload)
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, get_cuda_version_key
@@ -21,28 +20,33 @@ def get_cuda_stream(idx=None):
idx = get_current_device()
try:
from torch._C import _cuda_getCurrentRawStream
return _cuda_getCurrentRawStream(idx)
except ImportError:
import torch
return torch.cuda.current_stream(idx).cuda_stream
def get_current_device():
import torch
return torch.cuda.current_device()
def set_current_device(idx):
import torch
torch.cuda.set_device(idx)
def get_device_capability(idx):
import torch
return torch.cuda.get_device_capability(idx)
T = TypeVar('T')
T = TypeVar("T")
# -----------------------------------------------------------------------------
# Dependencies Finder
@@ -68,7 +72,9 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
if lhs is None or (
getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")
):
return None
return getattr(lhs, node.attr)
@@ -78,18 +84,21 @@ class DependenciesFinder(ast.NodeVisitor):
return
if inspect.isbuiltin(func):
return
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
return
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
assert isinstance(
func, JITFunction
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
noinline = str(getattr(func, 'noinline', False))
noinline = str(getattr(func, "noinline", False))
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
self.ret = hashlib.sha1(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@@ -116,7 +125,6 @@ class KernelInterface(Generic[T]):
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
divisibility = 16
@@ -133,26 +141,26 @@ class JITFunction(KernelInterface[T]):
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
if -(2**31) <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
return "fp32"
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
@staticmethod
def _device_of(arg):
if hasattr(arg, "device"):
if hasattr(arg.device, 'type'):
if hasattr(arg.device, "type"):
return arg.device.type
return ''
return ""
@staticmethod
def _pinned_memory_of(arg):
@@ -165,10 +173,10 @@ class JITFunction(KernelInterface[T]):
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
return arg.data_ptr() % JITFunction.divisibility == 0
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
return (arg is None,)
def _get_config(self, *args):
def is_divisible_by_16(x):
@@ -186,20 +194,23 @@ class JITFunction(KernelInterface[T]):
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
divisible_by_8 = {i for i, arg in enumerate(
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
divisible_by_16 = {
i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize
}
divisible_by_8 = {i for i, arg in enumerate(args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
equal_to_1 = {
i for i, arg in enumerate(args) if isinstance(
arg, int) and not isinstance(
arg, bool) and arg == 1 and i not in self.do_not_specialize}
i
for i, arg in enumerate(args)
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and i not in self.do_not_specialize
}
# folded equal_to_1 and None
# TODO: method to collect all folded args
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
ids_of_folded_args = equal_to_1 | none_args
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
return namedtuple(
"instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]
)(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
# return _triton.code_gen.instance_descriptor(divisible_by_16,
# equal_to_1)
@@ -207,7 +218,7 @@ class JITFunction(KernelInterface[T]):
def _type_of(key):
# None are nullptr -- implicitly converted to *i8
if key is None:
return '*i8'
return "*i8"
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
@@ -239,12 +250,25 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
def _call_hook(
self,
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
arg_reprs = ", ".join([f"{name}: {ty}" for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
key = str(key)
@@ -254,65 +278,95 @@ class JITFunction(KernelInterface[T]):
self.name = name
pass
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
configs=configs)
kwargs = dict(
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
return JITFunction.cache_hook(
key=key,
repr=repr,
fn=LegacyCompiler(module, name),
compile={"key": key, **kwargs},
is_manual_warmup=False,
already_compiled=False,
)
def _get_arg_specialization_key(self, arg_name, arg):
arg_annotation = self.__annotations__.get(arg_name, '')
if arg_annotation == '':
return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
arg_annotation = self.__annotations__.get(arg_name, "")
if arg_annotation == "":
return (
(arg.data_ptr() % JITFunction.divisibility == 0)
if hasattr(arg, "data_ptr")
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
if isinstance(arg, int)
else (False,)
elif 'Tensor' in arg_annotation:
return (arg.data_ptr() % JITFunction.divisibility == 0)
elif 'int' in arg_annotation or 'bool' in arg_annotation:
)
elif "Tensor" in arg_annotation:
return arg.data_ptr() % JITFunction.divisibility == 0
elif "int" in arg_annotation or "bool" in arg_annotation:
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
else:
return (False,)
def _get_arg_sig_key(self, arg_name, arg) -> str:
arg_annotation = self.__annotations__.get(arg_name, '')
if 'Tensor' in arg_annotation:
arg_annotation = self.__annotations__.get(arg_name, "")
if "Tensor" in arg_annotation:
return arg.dtype
elif arg_annotation == 'bool':
elif arg_annotation == "bool":
return "i1"
elif arg_annotation == 'float':
return 'fp32'
elif arg_annotation == "float":
return "fp32"
else:
return self._key_of(arg)
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
device_types = [device_type for device_type in device_types if device_type != '']
device_types = [device_type for device_type in device_types if device_type != ""]
# Return cuda if one of the input tensors is cuda
if 'cuda' in device_types:
if "cuda" in device_types:
import torch
return 'hip' if torch.version.hip else 'cuda'
is_cpu = all(device_type == 'cpu' for device_type in device_types)
return "hip" if torch.version.hip else "cuda"
is_cpu = all(device_type == "cpu" for device_type in device_types)
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
# Return cuda if all the input tensors are cpu while the memory is pinned
if is_cpu and is_pinned_memory:
return 'cuda'
return "cuda"
return device_types[0] if len(device_types) > 0 else 'cuda'
return device_types[0] if len(device_types) > 0 else "cuda"
def _make_launcher(self):
regular_args = [arg for i, arg in enumerate(
self.arg_names) if i not in self.constexprs]
constexpr_args = [arg for i, arg in enumerate(
self.arg_names) if i in self.constexprs]
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [arg for i, arg in enumerate(self.arg_names) if i in self.constexprs]
def regular_args_v(args_proxy):
return [args_proxy[arg_name] for arg_name in regular_args]
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
from ..compiler import (CompiledKernel, compile,
get_arch_default_num_stages,
get_arch_default_num_warps)
def launcher_body(
args_proxy,
grid,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
stream,
warmup,
device,
device_type,
):
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
specializations = []
@@ -332,25 +386,26 @@ class JITFunction(KernelInterface[T]):
grid_2 = grid[2] if grid_size > 2 else 1
if device_type is None:
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
device_types = [_device_type for _device_type in device_types if _device_type != '']
device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
regular_args_v(args_proxy)])
device_types = [_device_type for _device_type in device_types if _device_type != ""]
device_type = self._conclude_device_type(
device_types, [self._pinned_memory_of(arg) for arg in regular_args_v(args_proxy)]
)
device_backend = None
if device_type not in ['cuda']:
if device_type not in ["cuda"]:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError('Cannot find backend for ' + device_type)
raise ValueError("Cannot find backend for " + device_type)
if device is None:
if device_type in ['cuda']:
if device_type in ["cuda"]:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ['cuda']:
if device_type in ["cuda"]:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
@@ -360,11 +415,22 @@ class JITFunction(KernelInterface[T]):
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
if device_type in ['cuda']:
if device_type in ["cuda"]:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
key = (
version_key,
sig_key,
constexpr_key,
spec_key,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
self.debug,
)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
@@ -375,37 +441,104 @@ class JITFunction(KernelInterface[T]):
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
bin.c_wrapper(
grid_0,
grid_1,
grid_2,
bin.num_warps,
bin.num_ctas,
bin.clusterDims[0],
bin.clusterDims[1],
bin.clusterDims[2],
bin.shared,
stream,
bin.cu_function,
CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook,
bin,
*args,
)
return bin
# kernel not cached -- compile
else:
# build dict of constant values
args = regular_args_v(args_proxy)
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
configs = self._get_config(*all_args),
configs = (self._get_config(*all_args),)
constants = self._make_constants(constexpr_key)
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
constants.update({i: 1 for i in configs[0].equal_to_1})
# build kernel signature -- doesn't include specialized arguments
signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
signature = {
i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs
}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
if not self._call_hook(
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
bin = compile(
self,
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
debug=self.debug,
device_type=device_type,
)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
bin.c_wrapper(
grid_0,
grid_1,
grid_2,
bin.num_warps,
bin.num_ctas,
bin.clusterDims[0],
bin.clusterDims[1],
bin.clusterDims[2],
bin.shared,
stream,
bin.cu_function,
CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook,
bin,
*args,
)
self.cache[device][key] = bin
return bin
return None
# create a wrapper to call launcher_body
args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
args_map = ",".join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ", ".join(
name
if dflt == inspect._empty
else f"{name} = triton.language.dtype('{dflt}')"
if dtype.is_dtype(f"{dflt}")
else f"{name} = {dflt}"
for name, dflt in zip(self.arg_names, self.arg_defaults)
)
args_signature = args_signature + ", " if len(args_signature) > 0 else ""
src = f"""
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
@@ -426,7 +559,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
self.src = self.src[self.src.find("def") :]
# cache of just-in-time compiled kernels
self.cache = defaultdict(dict)
self.hash = None
@@ -439,11 +572,13 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
# annotations
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if "constexpr" in ty]
# specialization hints
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
self.do_not_specialize = {
regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize
}
# tma info
self.tensormaps_info = TMAInfos()
# launcher
@@ -482,12 +617,12 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
if name == "kernel_decorators":
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
if name == "src":
self.hash = None
def __repr__(self):
@@ -553,12 +688,14 @@ def jit(
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)
else:
return decorator
# -----------------------------------------------------------------------------
# Utilities for mocking tensors
# -----------------------------------------------------------------------------
@@ -569,10 +706,10 @@ class MockTensor:
Can be used in place of real tensors when calling:
kernel.warmup(MockTensor(torch.float32), ...)
"""
@staticmethod
def wrap_dtype(arg):
if arg.__class__.__name__ == "dtype" and\
arg.__module__ == "torch":
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
return MockTensor(arg)
return arg
@@ -599,7 +736,7 @@ class TensorWrapper:
return self.base.stride(i)
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
return f"TensorWrapper[{self.dtype}]({self.base})"
def element_size(self):
return self.base.element_size()
@@ -617,4 +754,4 @@ def reinterpret(tensor, dtype):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")