mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Apply ruff pre-commit to python/triton/runtime. (#2558)
We're in the process of incrementally converting from autopep8 + flake8 + isort to ruff, on a directory-by-directory basis. The motivation to switch away from autopep8 is that I can't get it to wrap long lines, even with -aaa. This seems to be a known problem, https://github.com/hhatto/autopep8/issues/497. See more details about alternatives tried in https://github.com/openai/triton/pull/2557.
This commit is contained in:
@@ -16,6 +16,7 @@ from .. import language as tl
|
||||
|
||||
@jit
|
||||
def _fwd_kernel(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale,
|
||||
L,
|
||||
Out,
|
||||
@@ -28,6 +29,7 @@ def _fwd_kernel(
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
# fmt: on
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
@@ -40,7 +42,7 @@ def _fwd_kernel(
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, vk_offset),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
@@ -48,7 +50,7 @@ def _fwd_kernel(
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(vk_offset, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
@@ -104,7 +106,7 @@ def _fwd_kernel(
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(vk_offset + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
|
||||
@@ -112,9 +114,11 @@ def _fwd_kernel(
|
||||
|
||||
@jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO,
|
||||
Out,
|
||||
DO,
|
||||
Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
@@ -129,6 +133,7 @@ def _bwd_preprocess(
|
||||
|
||||
@jit
|
||||
def _bwd_kernel_one_col_block(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale, qk_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
@@ -146,6 +151,7 @@ def _bwd_kernel_one_col_block(
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
# fmt: on
|
||||
):
|
||||
if CAUSAL:
|
||||
lo = start_n * BLOCK_M
|
||||
@@ -153,7 +159,7 @@ def _bwd_kernel_one_col_block(
|
||||
lo = 0
|
||||
|
||||
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
|
||||
DQ_offset = (off_z * stride_qz + off_h * stride_qh)
|
||||
DQ_offset = off_z * stride_qz + off_h * stride_qh
|
||||
K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
|
||||
V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
|
||||
if SEQUENCE_PARALLEL:
|
||||
@@ -188,7 +194,7 @@ def _bwd_kernel_one_col_block(
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
if CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
|
||||
else:
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
@@ -260,7 +266,7 @@ def _bwd_kernel(
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
@@ -268,7 +274,7 @@ def _bwd_kernel(
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
@@ -276,7 +282,7 @@ def _bwd_kernel(
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
DO_block_ptr = tl.make_block_ptr(
|
||||
base=DO,
|
||||
@@ -284,7 +290,7 @@ def _bwd_kernel(
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
@@ -293,7 +299,7 @@ def _bwd_kernel(
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
@@ -302,7 +308,7 @@ def _bwd_kernel(
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
DK_block_ptr = tl.make_block_ptr(
|
||||
@@ -311,7 +317,7 @@ def _bwd_kernel(
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
DV_block_ptr = tl.make_block_ptr(
|
||||
base=DV,
|
||||
@@ -319,13 +325,14 @@ def _bwd_kernel(
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
|
||||
if not SEQUENCE_PARALLEL:
|
||||
for start_n in range(0, num_block_n):
|
||||
_bwd_kernel_one_col_block(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
@@ -342,10 +349,12 @@ def _bwd_kernel(
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
# fmt: on
|
||||
)
|
||||
else:
|
||||
start_n = tl.program_id(1)
|
||||
_bwd_kernel_one_col_block(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
@@ -362,11 +371,11 @@ def _bwd_kernel(
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
# fmt: on
|
||||
)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
|
||||
# only support for Ampere now
|
||||
@@ -384,6 +393,7 @@ class _attention(torch.autograd.Function):
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel[grid](
|
||||
# fmt: off
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
o,
|
||||
@@ -396,7 +406,9 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
num_stages=4,
|
||||
# fmt: on
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
@@ -424,12 +436,15 @@ class _attention(torch.autograd.Function):
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
|
||||
o, do,
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1],)](
|
||||
o,
|
||||
do,
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
BLOCK_M=BLOCK,
|
||||
D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
|
||||
# fmt: off
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
@@ -448,6 +463,7 @@ class _attention(torch.autograd.Function):
|
||||
MMA_V3=MMA_V3,
|
||||
num_warps=8,
|
||||
num_stages=1,
|
||||
# fmt: on
|
||||
)
|
||||
|
||||
if len(dq.shape) == 5:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .autotuner import Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics
|
||||
from .driver import driver
|
||||
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ from .jit import KernelInterface
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = (
|
||||
f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. "
|
||||
+ "Reducing block sizes or `num_stages` may help."
|
||||
)
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
@@ -26,12 +26,12 @@ class OutOfResources(Exception):
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
|
||||
'''
|
||||
"""
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
"""
|
||||
if not configs:
|
||||
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
|
||||
else:
|
||||
@@ -46,13 +46,14 @@ class Autotuner(KernelInterface):
|
||||
def _hook(args):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
# prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'early_config_prune' in prune_configs_by:
|
||||
early_config_prune = prune_configs_by['early_config_prune']
|
||||
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
|
||||
if "early_config_prune" in prune_configs_by:
|
||||
early_config_prune = prune_configs_by["early_config_prune"]
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
@@ -78,15 +79,20 @@ class Autotuner(KernelInterface):
|
||||
if config.pre_hook:
|
||||
config.pre_hook(full_nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current)
|
||||
self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current,
|
||||
)
|
||||
|
||||
try:
|
||||
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
|
||||
except OutOfResources:
|
||||
return [float('inf'), float('inf'), float('inf')]
|
||||
return [float("inf"), float("inf"), float("inf")]
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
@@ -105,8 +111,7 @@ class Autotuner(KernelInterface):
|
||||
# prune configs
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
@@ -119,9 +124,15 @@ class Autotuner(KernelInterface):
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(full_nargs)
|
||||
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
|
||||
ret = self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
self.nargs = None
|
||||
return ret
|
||||
|
||||
@@ -135,17 +146,19 @@ class Autotuner(KernelInterface):
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent)
|
||||
config: self.perf_model(
|
||||
**self.nargs,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent,
|
||||
)
|
||||
for config in pruned_configs
|
||||
}
|
||||
pruned_configs = sorted(
|
||||
est_timing.keys(),
|
||||
key=lambda x: est_timing[x])[
|
||||
:top_k]
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
return pruned_configs
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -195,14 +208,13 @@ class Config:
|
||||
def __str__(self):
|
||||
res = []
|
||||
for k, v in self.kwargs.items():
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
res.append(f'num_ctas: {self.num_ctas}')
|
||||
res.append(f'num_stages: {self.num_stages}')
|
||||
res.append(
|
||||
f'enable_warp_specialization: {self.enable_warp_specialization}')
|
||||
res.append(f'enable_persistent: {self.enable_persistent}')
|
||||
return ', '.join(res)
|
||||
res.append(f"{k}: {v}")
|
||||
res.append(f"num_warps: {self.num_warps}")
|
||||
res.append(f"num_ctas: {self.num_ctas}")
|
||||
res.append(f"num_stages: {self.num_stages}")
|
||||
res.append(f"enable_warp_specialization: {self.enable_warp_specialization}")
|
||||
res.append(f"enable_persistent: {self.enable_persistent}")
|
||||
return ", ".join(res)
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
|
||||
@@ -241,6 +253,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
|
||||
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
|
||||
:type rep: int
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
|
||||
@@ -248,7 +261,6 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
|
||||
|
||||
|
||||
class Heuristics(KernelInterface):
|
||||
|
||||
def __init__(self, fn, arg_names, values) -> None:
|
||||
self.fn = fn
|
||||
self.values = values
|
||||
@@ -276,6 +288,7 @@ def heuristics(values):
|
||||
each such function takes a list of positional arguments as input.
|
||||
:type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Heuristics(fn, fn.arg_names, values)
|
||||
|
||||
|
||||
@@ -47,17 +47,17 @@ class FileCacheManager(CacheManager):
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
if (dump):
|
||||
if dump:
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif (override):
|
||||
elif override:
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
@@ -141,6 +141,7 @@ def get_cache_manager(key) -> CacheManager:
|
||||
|
||||
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
|
||||
import importlib
|
||||
|
||||
module_path, clz_nme = user_cache_manager.split(":")
|
||||
module = importlib.import_module(module_path)
|
||||
__cache_cls = getattr(module, clz_nme)
|
||||
|
||||
@@ -9,7 +9,6 @@ from .cache import get_cache_manager
|
||||
|
||||
|
||||
class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
CUDA = 0
|
||||
HIP = 1
|
||||
|
||||
@@ -19,15 +18,16 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# CUDA
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class CudaUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -47,6 +47,7 @@ class CudaUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -64,9 +65,8 @@ class CudaUtils(object):
|
||||
|
||||
|
||||
class CudaDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -74,6 +74,7 @@ class CudaDriver(DriverBase):
|
||||
self.utils = CudaUtils()
|
||||
self.backend = self.CUDA
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# HIP
|
||||
# -----------------------------
|
||||
@@ -81,7 +82,7 @@ class CudaDriver(DriverBase):
|
||||
|
||||
class HIPUtils(object):
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -101,6 +102,7 @@ class HIPUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -109,9 +111,8 @@ class HIPUtils(object):
|
||||
|
||||
|
||||
class HIPDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -121,9 +122,8 @@ class HIPDriver(DriverBase):
|
||||
|
||||
|
||||
class UnsupportedDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -131,6 +131,7 @@ class UnsupportedDriver(DriverBase):
|
||||
self.utils = None
|
||||
self.backend = None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Driver
|
||||
# -----------------------------
|
||||
@@ -150,7 +151,7 @@ class LazyProxy:
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in ['_init_fn', '_obj']:
|
||||
if name in ["_init_fn", "_obj"]:
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
self._initialize_obj()
|
||||
@@ -172,6 +173,7 @@ class LazyProxy:
|
||||
|
||||
def initialize_driver():
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
|
||||
return HIPDriver()
|
||||
elif torch.cuda.is_available():
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
|
||||
self.message += ". Reducing block sizes or `num_stages` may help."
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
|
||||
@@ -37,7 +37,6 @@ def str_to_ty(name):
|
||||
|
||||
|
||||
class TensorHandle:
|
||||
|
||||
def __init__(self, data, dtype):
|
||||
self.data = data
|
||||
self.dtype = dtype
|
||||
@@ -47,7 +46,6 @@ class TensorHandle:
|
||||
|
||||
|
||||
class BlockPointerHandle:
|
||||
|
||||
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
|
||||
self.base = base
|
||||
self.shape = shape
|
||||
@@ -78,12 +76,13 @@ def wrap_ret(compute_ret_ty):
|
||||
def wrapped(*args, **kwargs):
|
||||
ret = fn(*args, **kwargs)
|
||||
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Builder:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.arch = None
|
||||
# pass
|
||||
@@ -249,11 +248,13 @@ class Builder:
|
||||
# ternary functions
|
||||
def ternary_op(self, lhs, rhs, other, op):
|
||||
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
|
||||
|
||||
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
|
||||
|
||||
# unary functions
|
||||
def unary_op(self, arg, op):
|
||||
return TensorHandle(op(arg.data), arg.dtype)
|
||||
|
||||
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
|
||||
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
|
||||
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
|
||||
@@ -279,7 +280,9 @@ class Builder:
|
||||
dtype_tt = ptr.dtype.element_ty
|
||||
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
|
||||
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
|
||||
def create_tensor_pointer_load(
|
||||
self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile
|
||||
):
|
||||
ptrs, masks = ptr.materialize_pointers(boundary_check)
|
||||
assert padding_option is None
|
||||
other = None
|
||||
@@ -297,6 +300,7 @@ class Builder:
|
||||
|
||||
def create_int_to_ptr(self, val, dst_ty):
|
||||
return TensorHandle(val.data.astype(np.uint64), dst_ty)
|
||||
|
||||
# def create_cat(self, lhs, rhs):
|
||||
# pass
|
||||
|
||||
@@ -360,7 +364,9 @@ class Builder:
|
||||
|
||||
|
||||
def patch_attr(obj, name, member, builder):
|
||||
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
|
||||
new_member = lambda *args, member=member, **kwargs: (
|
||||
member(*args, **{k: v for k, v in kwargs.items() if k != "_builder"}, _builder=builder)
|
||||
)
|
||||
setattr(obj, name, new_member)
|
||||
|
||||
|
||||
@@ -384,8 +390,8 @@ def _patch_lang_core(lang, builder):
|
||||
def _new_reduce(input, axis, combine_fn):
|
||||
fn = combine_fn.fn.__name__
|
||||
mapping = {
|
||||
'maximum': np.max,
|
||||
'_sum_combine': np.sum,
|
||||
"maximum": np.max,
|
||||
"_sum_combine": np.sum,
|
||||
}
|
||||
ret = mapping[fn](input.handle.data, axis=axis)
|
||||
ret_type = tl.block_type(input.dtype, ret.shape)
|
||||
@@ -397,12 +403,12 @@ def _patch_lang_core(lang, builder):
|
||||
def _patch_lang_math(lang, builder):
|
||||
math = lang.math
|
||||
mapping = {
|
||||
'abs': 'abs',
|
||||
'acos': 'arccos',
|
||||
'asin': 'arcsin',
|
||||
'exp2': 'exp2',
|
||||
'log2': 'log2',
|
||||
'max': 'maximum',
|
||||
"abs": "abs",
|
||||
"acos": "arccos",
|
||||
"asin": "arcsin",
|
||||
"exp2": "exp2",
|
||||
"log2": "log2",
|
||||
"max": "maximum",
|
||||
}
|
||||
|
||||
def make_numpy(name):
|
||||
@@ -414,15 +420,19 @@ def _patch_lang_math(lang, builder):
|
||||
ret = getattr(np, mapping[name])(*args, **kwargs)
|
||||
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
|
||||
return ret
|
||||
|
||||
return impl
|
||||
|
||||
def make_fallback(name):
|
||||
def fallback(*args, **kwargs):
|
||||
raise NotImplementedError(f"""
|
||||
raise NotImplementedError(
|
||||
f"""
|
||||
{name} not supported in interpreter mode: no known numpy implementation.
|
||||
If you think that {name} in fact does have a numpy implementation, please add it
|
||||
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
return fallback
|
||||
|
||||
for name, member in inspect.getmembers(math):
|
||||
@@ -438,7 +448,7 @@ def _implicit_cvt(arg):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
if hasattr(arg, 'data_ptr'):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
@@ -453,28 +463,28 @@ def _unwrap(tensor):
|
||||
|
||||
builder = Builder()
|
||||
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
|
||||
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
def __init__(self, fn, arg_names, grid):
|
||||
from .jit import _normalize_ty # TODO: modularize
|
||||
|
||||
self.fn = fn
|
||||
self.arg_names = arg_names
|
||||
self.grid = grid
|
||||
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
_patch_lang_math(lang[0], builder)
|
||||
|
||||
def __call__(self, *args_dev, **kwargs):
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
|
||||
# removes reserved keywords from kwargs
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
|
||||
# remaps core language functions to interpreted ones
|
||||
@@ -495,26 +505,26 @@ class GridExecutor:
|
||||
self.fn(**args)
|
||||
# copy arguments back to propagate side-effects
|
||||
for arg_dev, arg_hst in zip(args_dev, args_hst):
|
||||
if hasattr(arg_dev, 'data_ptr'):
|
||||
if hasattr(arg_dev, "data_ptr"):
|
||||
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
|
||||
|
||||
|
||||
class InterpretedFunction:
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
|
||||
def __init__(self, fn) -> None:
|
||||
self.fn = fn
|
||||
|
||||
def run(*args, **kwargs):
|
||||
grid = kwargs['grid']
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
|
||||
grid = kwargs["grid"]
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
|
||||
|
||||
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
|
||||
|
||||
self.run = run
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
|
||||
@@ -7,8 +7,7 @@ import inspect
|
||||
import os
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, get_cuda_version_key
|
||||
@@ -21,28 +20,33 @@ def get_cuda_stream(idx=None):
|
||||
idx = get_current_device()
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream
|
||||
|
||||
return _cuda_getCurrentRawStream(idx)
|
||||
except ImportError:
|
||||
import torch
|
||||
|
||||
return torch.cuda.current_stream(idx).cuda_stream
|
||||
|
||||
|
||||
def get_current_device():
|
||||
import torch
|
||||
|
||||
return torch.cuda.current_device()
|
||||
|
||||
|
||||
def set_current_device(idx):
|
||||
import torch
|
||||
|
||||
torch.cuda.set_device(idx)
|
||||
|
||||
|
||||
def get_device_capability(idx):
|
||||
import torch
|
||||
|
||||
return torch.cuda.get_device_capability(idx)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
@@ -68,7 +72,9 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
if lhs is None or (
|
||||
getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")
|
||||
):
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -78,18 +84,21 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
|
||||
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
|
||||
return
|
||||
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
|
||||
assert isinstance(
|
||||
func, JITFunction
|
||||
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
|
||||
if func.hash is None:
|
||||
tree = ast.parse(func.src)
|
||||
finder = DependenciesFinder(func.__globals__, func.src)
|
||||
finder.visit(tree)
|
||||
func.hash = finder.ret
|
||||
noinline = str(getattr(func, 'noinline', False))
|
||||
noinline = str(getattr(func, "noinline", False))
|
||||
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
|
||||
self.ret = hashlib.sha1(self.ret).hexdigest()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -116,7 +125,6 @@ class KernelInterface(Generic[T]):
|
||||
|
||||
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Hook for inspecting compiled functions and modules
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
@@ -133,26 +141,26 @@ class JITFunction(KernelInterface[T]):
|
||||
elif isinstance(arg, bool):
|
||||
return "i1"
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
if -(2**31) <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
return "i64"
|
||||
elif isinstance(arg, float):
|
||||
return 'fp32'
|
||||
return "fp32"
|
||||
elif arg is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
|
||||
|
||||
@staticmethod
|
||||
def _device_of(arg):
|
||||
if hasattr(arg, "device"):
|
||||
if hasattr(arg.device, 'type'):
|
||||
if hasattr(arg.device, "type"):
|
||||
return arg.device.type
|
||||
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _pinned_memory_of(arg):
|
||||
@@ -165,10 +173,10 @@ class JITFunction(KernelInterface[T]):
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
return arg.data_ptr() % JITFunction.divisibility == 0
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None, )
|
||||
return (arg is None,)
|
||||
|
||||
def _get_config(self, *args):
|
||||
def is_divisible_by_16(x):
|
||||
@@ -186,20 +194,23 @@ class JITFunction(KernelInterface[T]):
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
divisible_by_8 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
|
||||
|
||||
divisible_by_16 = {
|
||||
i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize
|
||||
}
|
||||
divisible_by_8 = {i for i, arg in enumerate(args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {
|
||||
i for i, arg in enumerate(args) if isinstance(
|
||||
arg, int) and not isinstance(
|
||||
arg, bool) and arg == 1 and i not in self.do_not_specialize}
|
||||
i
|
||||
for i, arg in enumerate(args)
|
||||
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and i not in self.do_not_specialize
|
||||
}
|
||||
# folded equal_to_1 and None
|
||||
# TODO: method to collect all folded args
|
||||
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
|
||||
ids_of_folded_args = equal_to_1 | none_args
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
|
||||
return namedtuple(
|
||||
"instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]
|
||||
)(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16,
|
||||
# equal_to_1)
|
||||
|
||||
@@ -207,7 +218,7 @@ class JITFunction(KernelInterface[T]):
|
||||
def _type_of(key):
|
||||
# None are nullptr -- implicitly converted to *i8
|
||||
if key is None:
|
||||
return '*i8'
|
||||
return "*i8"
|
||||
dtype_str = str(key).split(".")[-1]
|
||||
tys = {
|
||||
"bool": "i1",
|
||||
@@ -239,12 +250,25 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
def _call_hook(
|
||||
self,
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
arg_reprs = ", ".join([f"{name}: {ty}" for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
@@ -254,65 +278,95 @@ class JITFunction(KernelInterface[T]):
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
kwargs = dict(
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
return JITFunction.cache_hook(
|
||||
key=key,
|
||||
repr=repr,
|
||||
fn=LegacyCompiler(module, name),
|
||||
compile={"key": key, **kwargs},
|
||||
is_manual_warmup=False,
|
||||
already_compiled=False,
|
||||
)
|
||||
|
||||
def _get_arg_specialization_key(self, arg_name, arg):
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if arg_annotation == '':
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
|
||||
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
|
||||
arg_annotation = self.__annotations__.get(arg_name, "")
|
||||
if arg_annotation == "":
|
||||
return (
|
||||
(arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
if hasattr(arg, "data_ptr")
|
||||
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
|
||||
if isinstance(arg, int)
|
||||
else (False,)
|
||||
elif 'Tensor' in arg_annotation:
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
elif 'int' in arg_annotation or 'bool' in arg_annotation:
|
||||
)
|
||||
elif "Tensor" in arg_annotation:
|
||||
return arg.data_ptr() % JITFunction.divisibility == 0
|
||||
elif "int" in arg_annotation or "bool" in arg_annotation:
|
||||
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
|
||||
else:
|
||||
return (False,)
|
||||
|
||||
def _get_arg_sig_key(self, arg_name, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if 'Tensor' in arg_annotation:
|
||||
arg_annotation = self.__annotations__.get(arg_name, "")
|
||||
if "Tensor" in arg_annotation:
|
||||
return arg.dtype
|
||||
elif arg_annotation == 'bool':
|
||||
elif arg_annotation == "bool":
|
||||
return "i1"
|
||||
elif arg_annotation == 'float':
|
||||
return 'fp32'
|
||||
elif arg_annotation == "float":
|
||||
return "fp32"
|
||||
else:
|
||||
return self._key_of(arg)
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
device_types = [device_type for device_type in device_types if device_type != ""]
|
||||
# Return cuda if one of the input tensors is cuda
|
||||
if 'cuda' in device_types:
|
||||
if "cuda" in device_types:
|
||||
import torch
|
||||
return 'hip' if torch.version.hip else 'cuda'
|
||||
|
||||
is_cpu = all(device_type == 'cpu' for device_type in device_types)
|
||||
return "hip" if torch.version.hip else "cuda"
|
||||
|
||||
is_cpu = all(device_type == "cpu" for device_type in device_types)
|
||||
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
|
||||
# Return cuda if all the input tensors are cpu while the memory is pinned
|
||||
if is_cpu and is_pinned_memory:
|
||||
return 'cuda'
|
||||
return "cuda"
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
return device_types[0] if len(device_types) > 0 else "cuda"
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i in self.constexprs]
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [arg for i, arg in enumerate(self.arg_names) if i in self.constexprs]
|
||||
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
|
||||
from ..compiler import (CompiledKernel, compile,
|
||||
get_arch_default_num_stages,
|
||||
get_arch_default_num_warps)
|
||||
def launcher_body(
|
||||
args_proxy,
|
||||
grid,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
stream,
|
||||
warmup,
|
||||
device,
|
||||
device_type,
|
||||
):
|
||||
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
|
||||
|
||||
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
|
||||
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
|
||||
specializations = []
|
||||
@@ -332,25 +386,26 @@ class JITFunction(KernelInterface[T]):
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != '']
|
||||
device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
|
||||
regular_args_v(args_proxy)])
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != ""]
|
||||
device_type = self._conclude_device_type(
|
||||
device_types, [self._pinned_memory_of(arg) for arg in regular_args_v(args_proxy)]
|
||||
)
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ['cuda']:
|
||||
if device_type not in ["cuda"]:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError('Cannot find backend for ' + device_type)
|
||||
raise ValueError("Cannot find backend for " + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ['cuda']:
|
||||
if device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ['cuda']:
|
||||
if device_type in ["cuda"]:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
@@ -360,11 +415,22 @@ class JITFunction(KernelInterface[T]):
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
if device_type in ['cuda']:
|
||||
if device_type in ["cuda"]:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
key = (
|
||||
version_key,
|
||||
sig_key,
|
||||
constexpr_key,
|
||||
spec_key,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
self.debug,
|
||||
)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
@@ -375,37 +441,104 @@ class JITFunction(KernelInterface[T]):
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*args,
|
||||
)
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
|
||||
configs = self._get_config(*all_args),
|
||||
configs = (self._get_config(*all_args),)
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
|
||||
constants.update({i: 1 for i in configs[0].equal_to_1})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
|
||||
signature = {
|
||||
i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs
|
||||
}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
if not self._call_hook(
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
bin = compile(
|
||||
self,
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
debug=self.debug,
|
||||
device_type=device_type,
|
||||
)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*args,
|
||||
)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
|
||||
# create a wrapper to call launcher_body
|
||||
args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
args_map = ",".join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ", ".join(
|
||||
name
|
||||
if dflt == inspect._empty
|
||||
else f"{name} = triton.language.dtype('{dflt}')"
|
||||
if dtype.is_dtype(f"{dflt}")
|
||||
else f"{name} = {dflt}"
|
||||
for name, dflt in zip(self.arg_names, self.arg_defaults)
|
||||
)
|
||||
args_signature = args_signature + ", " if len(args_signature) > 0 else ""
|
||||
src = f"""
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
@@ -426,7 +559,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
self.src = self.src[self.src.find("def") :]
|
||||
# cache of just-in-time compiled kernels
|
||||
self.cache = defaultdict(dict)
|
||||
self.hash = None
|
||||
@@ -439,11 +572,13 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
# annotations
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if "constexpr" in ty]
|
||||
# specialization hints
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
self.do_not_specialize = {
|
||||
regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize
|
||||
}
|
||||
# tma info
|
||||
self.tensormaps_info = TMAInfos()
|
||||
# launcher
|
||||
@@ -482,12 +617,12 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
def __setattr__(self, name, value):
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
if name == 'kernel_decorators':
|
||||
if name == "kernel_decorators":
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
if name == 'src':
|
||||
if name == "src":
|
||||
self.hash = None
|
||||
|
||||
def __repr__(self):
|
||||
@@ -553,12 +688,14 @@ def jit(
|
||||
debug=debug,
|
||||
noinline=noinline,
|
||||
)
|
||||
|
||||
if fn is not None:
|
||||
return decorator(fn)
|
||||
|
||||
else:
|
||||
return decorator
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utilities for mocking tensors
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -569,10 +706,10 @@ class MockTensor:
|
||||
Can be used in place of real tensors when calling:
|
||||
kernel.warmup(MockTensor(torch.float32), ...)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def wrap_dtype(arg):
|
||||
if arg.__class__.__name__ == "dtype" and\
|
||||
arg.__module__ == "torch":
|
||||
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
|
||||
return MockTensor(arg)
|
||||
return arg
|
||||
|
||||
@@ -599,7 +736,7 @@ class TensorWrapper:
|
||||
return self.base.stride(i)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
return f"TensorWrapper[{self.dtype}]({self.base})"
|
||||
|
||||
def element_size(self):
|
||||
return self.base.element_size()
|
||||
@@ -617,4 +754,4 @@ def reinterpret(tensor, dtype):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
|
||||
|
||||
Reference in New Issue
Block a user