Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117

Conflicts:
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	python/setup.py
	python/test/unit/language/assert_helper.py
	python/test/unit/operators/test_flash_attention.py
	python/test/unit/runtime/test_subproc.py
	python/triton/compiler/compiler.py
	python/triton/language/semantic.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/tutorials/03-matrix-multiplication.py
	python/tutorials/05-layer-norm.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-17 20:42:12 +00:00
179 changed files with 10116 additions and 6835 deletions

View File

@@ -1,8 +1,6 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
from .driver import driver
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
version_key)
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
__all__ = [
"driver",
@@ -12,7 +10,6 @@ __all__ = [
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
"reinterpret",
"TensorWrapper",
"OutOfResources",

View File

@@ -9,11 +9,10 @@ from .jit import KernelInterface
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
"Reducing block sizes or `num_stages` may help.")
self.required = required
self.limit = limit
self.name = name
@@ -25,38 +24,77 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
<<<<<<< HEAD
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
'''
=======
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
restore_value,
prune_configs_by: Dict = None,
warmup=25,
rep=100,
):
"""
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
"""
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
self.arg_names = arg_names
# Reset to zero or restore values
self.reset_idx = []
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
self.restore_idx = []
if restore_value is not None:
self.restore_idx = [arg_names.index(k) for k in restore_value]
def _hook(args):
# Hook to reset or restore for required tensors
self.pre_hook = lambda args, reset_only=False: 0
self.post_hook = lambda args: 0
if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
def _pre_hook(args, reset_only=False):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if not reset_only:
self.restore_copies = [args[i].clone() for i in self.restore_idx]
self.pre_hook = _pre_hook
if len(self.restore_idx) > 0:
def _post_hook(args):
for i, j in enumerate(self.restore_idx):
args[j].copy_(self.restore_copies[i])
self.restore_copies = []
self.post_hook = _post_hook
# Prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
self.warmup = warmup
self.rep = rep
@@ -67,10 +105,8 @@ class Autotuner(KernelInterface):
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols.")
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
full_nargs = {**self.nargs, **current}
@@ -78,16 +114,22 @@ class Autotuner(KernelInterface):
def kernel_call():
if config.pre_hook:
config.pre_hook(full_nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current)
self.pre_hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current,
)
self.post_hook(args)
try:
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
return [float("inf"), float("inf"), float("inf")]
def get_best_config(self):
return self.best_config
@@ -110,12 +152,11 @@ class Autotuner(KernelInterface):
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.pre_hook(args, reset_only=True)
self.configs_timings = timings
if self.verbose:
print(str(key) + ": " + str(self.cache[key]))
@@ -126,9 +167,15 @@ class Autotuner(KernelInterface):
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
if config.pre_hook is not None:
config.pre_hook(full_nargs)
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
ret = self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
**kwargs,
**config.kwargs,
)
self.nargs = None
return ret
@@ -142,17 +189,20 @@ class Autotuner(KernelInterface):
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent)
config:
self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent,
)
for config in pruned_configs
}
pruned_configs = sorted(
est_timing.keys(),
key=lambda x: est_timing[x])[
:top_k]
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
@@ -195,13 +245,14 @@ class Config:
self.num_ctas = num_ctas
self.num_stages = num_stages
self.enable_warp_specialization = enable_warp_specialization
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
self.enable_persistent = False
self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
<<<<<<< HEAD
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
## Comment out Hopper specific parameters
@@ -214,6 +265,18 @@ class Config:
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
=======
res.append(f"{k}: {v}")
res.append(f"num_warps: {self.num_warps}")
res.append(f"num_ctas: {self.num_ctas}")
res.append(f"num_stages: {self.num_stages}")
res.append(f"enable_warp_specialization: {self.enable_warp_specialization}")
res.append(f"enable_persistent: {self.enable_persistent}")
return ", ".join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -244,6 +307,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.
:type restore_value: list[str]
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
:type warmup: int
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
@@ -251,8 +316,13 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
<<<<<<< HEAD
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
=======
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
return decorator
@@ -286,6 +356,7 @@ def heuristics(values):
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
return Heuristics(fn, fn.arg_names, values)

View File

@@ -1,27 +1,42 @@
#include "cuda.h"
#include <dlfcn.h>
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line) {
if (code != CUDA_SUCCESS) {
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
static bool gpuAssert(CUresult code, const char *file, int line) {
if (code == CUDA_SUCCESS)
return true;
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
return false;
}
#define CUDA_CHECK(ans) \
{ \
{ gpuAssert((ans), __FILE__, __LINE__); } \
}
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) \
return NULL; \
} while (0)
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
PyEval_RestoreThread(_save); \
return NULL; \
} \
} while (0)
#define ADD_ENUM_ITEM(value) \
do { \
@@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
@@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
CUcontext pctx = 0;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
}
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
@@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize));
Py_END_ALLOW_THREADS;
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
@@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
srcHost = (const void *)srcHostPtr;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuMemcpyHtoD(dstDevice, srcHost, byteCount));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemFree(dptr));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
}
// Call the function
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle(
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
oobFill));

View File

@@ -19,6 +19,7 @@ def default_dump_dir():
class CacheManager(ABC):
def __init__(self, key):
pass
@@ -44,20 +45,21 @@ class CacheManager(ABC):
class FileCacheManager(CacheManager):
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if (dump):
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif (override):
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
@@ -93,9 +95,8 @@ class FileCacheManager(CacheManager):
result = {}
for c in child_paths:
p = self._make_path(c)
if not os.path.exists(p):
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
result[c] = p
if os.path.exists(p):
result[c] = p
return result
# Note a group of pushed files as being part of a group
@@ -142,6 +143,7 @@ def get_cache_manager(key) -> CacheManager:
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib
module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)

View File

@@ -9,7 +9,6 @@ from .cache import get_cache_manager
class DriverBase(metaclass=abc.ABCMeta):
CUDA = 0
HIP = 1
@@ -19,6 +18,8 @@ class DriverBase(metaclass=abc.ABCMeta):
def __init__(self) -> None:
pass
# -----------------------------
# CUDA
# -----------------------------
@@ -27,7 +28,7 @@ class DriverBase(metaclass=abc.ABCMeta):
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
@@ -47,6 +48,7 @@ class CudaUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -66,7 +68,7 @@ class CudaUtils(object):
class CudaDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaDriver, cls).__new__(cls)
return cls.instance
@@ -74,14 +76,16 @@ class CudaDriver(DriverBase):
self.utils = CudaUtils()
self.backend = self.CUDA
# -----------------------------
# HIP
# -----------------------------
class HIPUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPUtils, cls).__new__(cls)
return cls.instance
@@ -101,6 +105,7 @@ class HIPUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -111,7 +116,7 @@ class HIPUtils(object):
class HIPDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPDriver, cls).__new__(cls)
return cls.instance
@@ -123,7 +128,7 @@ class HIPDriver(DriverBase):
class UnsupportedDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
return cls.instance
@@ -131,12 +136,14 @@ class UnsupportedDriver(DriverBase):
self.utils = None
self.backend = None
# -----------------------------
# Driver
# -----------------------------
class LazyProxy:
def __init__(self, init_fn):
self._init_fn = init_fn
self._obj = None
@@ -150,7 +157,7 @@ class LazyProxy:
return getattr(self._obj, name)
def __setattr__(self, name, value):
if name in ['_init_fn', '_obj']:
if name in ["_init_fn", "_obj"]:
super().__setattr__(name, value)
else:
self._initialize_obj()
@@ -172,6 +179,7 @@ class LazyProxy:
def initialize_driver():
import torch
if torch.version.hip is not None:
return HIPDriver()
elif torch.cuda.is_available():

View File

@@ -1,10 +1,8 @@
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
self.message += ". Reducing block sizes or `num_stages` may help."
self.required = required
self.limit = limit
self.name = name

View File

@@ -74,11 +74,15 @@ class BlockPointerHandle:
def wrap_ret(compute_ret_ty):
def wrapper(fn):
def wrapped(*args, **kwargs):
ret = fn(*args, **kwargs)
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
return wrapped
return wrapper
@@ -249,11 +253,13 @@ class Builder:
# ternary functions
def ternary_op(self, lhs, rhs, other, op):
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
# unary functions
def unary_op(self, arg, op):
return TensorHandle(op(arg.data), arg.dtype)
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
@@ -279,7 +285,8 @@ class Builder:
dtype_tt = ptr.dtype.element_ty
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
is_volatile):
ptrs, masks = ptr.materialize_pointers(boundary_check)
assert padding_option is None
other = None
@@ -297,6 +304,7 @@ class Builder:
def create_int_to_ptr(self, val, dst_ty):
return TensorHandle(val.data.astype(np.uint64), dst_ty)
# def create_cat(self, lhs, rhs):
# pass
@@ -360,7 +368,10 @@ class Builder:
def patch_attr(obj, name, member, builder):
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
new_member = lambda *args, member=member, **kwargs: (member(*args, **
{k: v
for k, v in kwargs.items()
if k != "_builder"}, _builder=builder))
setattr(obj, name, new_member)
@@ -384,8 +395,8 @@ def _patch_lang_core(lang, builder):
def _new_reduce(input, axis, combine_fn):
fn = combine_fn.fn.__name__
mapping = {
'maximum': np.max,
'_sum_combine': np.sum,
"maximum": np.max,
"_sum_combine": np.sum,
}
ret = mapping[fn](input.handle.data, axis=axis)
ret_type = tl.block_type(input.dtype, ret.shape)
@@ -397,15 +408,16 @@ def _patch_lang_core(lang, builder):
def _patch_lang_math(lang, builder):
math = lang.math
mapping = {
'abs': 'abs',
'acos': 'arccos',
'asin': 'arcsin',
'exp2': 'exp2',
'log2': 'log2',
'max': 'maximum',
"abs": "abs",
"acos": "arccos",
"asin": "arcsin",
"exp2": "exp2",
"log2": "log2",
"max": "maximum",
}
def make_numpy(name):
def impl(*args, **kwargs):
ret_type = args[0].type # TODO: incorrect
ret_dtype = args[0].dtype # TODO: incorrect
@@ -414,15 +426,18 @@ def _patch_lang_math(lang, builder):
ret = getattr(np, mapping[name])(*args, **kwargs)
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
return ret
return impl
def make_fallback(name):
def fallback(*args, **kwargs):
raise NotImplementedError(f"""
{name} not supported in interpreter mode: no known numpy implementation.
If you think that {name} in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
""")
return fallback
for name, member in inspect.getmembers(math):
@@ -438,7 +453,7 @@ def _implicit_cvt(arg):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
return tl.tensor(handle, ty)
if hasattr(arg, 'data_ptr'):
if hasattr(arg, "data_ptr"):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
return tl.tensor(handle, ty)
@@ -453,28 +468,29 @@ def _unwrap(tensor):
builder = Builder()
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
class GridExecutor:
def __init__(self, fn, arg_names, grid):
from .jit import _normalize_ty # TODO: modularize
self.fn = fn
self.arg_names = arg_names
self.grid = grid
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
_patch_lang_math(lang[0], builder)
def __call__(self, *args_dev, **kwargs):
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
# removes reserved keywords from kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
# remaps core language functions to interpreted ones
@@ -486,7 +502,7 @@ class GridExecutor:
# iterate through grid
grid = self.grid(args) if callable(self.grid) else self.grid
assert len(grid) <= 3
grid = grid + (1,) * (3 - len(grid))
grid = grid + (1, ) * (3 - len(grid))
builder.set_grid_dim(*grid)
for x in range(grid[0]):
for y in range(grid[1]):
@@ -495,7 +511,7 @@ class GridExecutor:
self.fn(**args)
# copy arguments back to propagate side-effects
for arg_dev, arg_hst in zip(args_dev, args_hst):
if hasattr(arg_dev, 'data_ptr'):
if hasattr(arg_dev, "data_ptr"):
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
@@ -504,17 +520,18 @@ class InterpretedFunction:
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
def __init__(self, fn) -> None:
self.fn = fn
def run(*args, **kwargs):
grid = kwargs['grid']
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
grid = kwargs["grid"]
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
self.run = run
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

View File

@@ -5,48 +5,48 @@ import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
overload)
from functools import cached_property
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, path_to_ptxas
from ..language.core import dtype
from ..common.backend import get_backend, get_cuda_version_key
from .interpreter import InterpretedFunction
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
def get_cuda_stream(idx=None):
if idx is None:
idx = get_current_device()
try:
from torch._C import _cuda_getCurrentRawStream
return _cuda_getCurrentRawStream(idx)
except ImportError:
import torch
return torch.cuda.current_stream(idx).cuda_stream
def get_current_device():
import torch
return torch.cuda.current_device()
def set_current_device(idx):
import torch
torch.cuda.set_device(idx)
def get_device_capability(idx):
import torch
return torch.cuda.get_device_capability(idx)
T = TypeVar('T')
T = TypeVar("T")
# -----------------------------------------------------------------------------
# Dependencies Finder
@@ -72,7 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
or getattr(lhs, "__name__", "").endswith(".triton")):
return None
return getattr(lhs, node.attr)
@@ -82,55 +83,26 @@ class DependenciesFinder(ast.NodeVisitor):
return
if inspect.isbuiltin(func):
return
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
return
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
assert isinstance(
func, JITFunction
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
noinline = str(getattr(func, 'noinline', False))
noinline = str(getattr(func, "noinline", False))
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
self.ret = hashlib.sha1(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024 ** 2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# ptxas version
ptxas = path_to_ptxas()[0]
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
def _normalize_ty(ty) -> str:
if isinstance(ty, type):
return ty.__name__
@@ -139,6 +111,85 @@ def _normalize_ty(ty) -> str:
return repr(ty)
class KernelParam:
    """Represents a parameter to a @jit'ed function.

    A parameter is just the name plus metadata; a parameter plus a value is a
    KernelArg.
    """

    def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
        # Positional index of this parameter in the kernel signature.
        self.num = num
        self._param = param
        # When True, the argument's value is excluded from specialization keys.
        self.do_not_specialize = do_not_specialize

    @cached_property
    def name(self) -> str:
        return self._param.name

    @cached_property
    def annotation(self) -> str:
        """Normalized string form of the annotation, or "" if absent."""
        annotation = self._param.annotation
        # Compare against the sentinel with `is`: `==` would invoke a
        # user-defined __eq__ on arbitrary annotation objects.
        if annotation is inspect.Parameter.empty or not annotation:
            return ""
        return _normalize_ty(annotation)

    @cached_property
    def is_constexpr(self) -> bool:
        # Matches both `tl.constexpr` and bare `constexpr` annotations.
        return "constexpr" in self.annotation

    @property
    def default(self):
        return self._param.default

    @property
    def has_default(self) -> bool:
        # Identity check against the sentinel: `!=` would invoke a
        # user-defined __eq__ on the default value (e.g. a numpy array's
        # elementwise comparison raises when coerced to bool).
        return self._param.default is not inspect.Parameter.empty
class KernelArg:
    """Represents an argument to a @jit'ed function.

    An argument is a parameter plus a value.
    """

    def __init__(self, value, param):
        self.value = value
        self.param = param

    @property
    def name(self):
        return self.param.name

    def signature_key(self):
        """Return the type component of the kernel cache key for this arg."""
        annotation = self.param.annotation
        if "Tensor" in annotation:
            return self.value.dtype
        if annotation == "bool":
            return "i1"
        if annotation == "float":
            return "fp32"
        # No usable annotation: infer the key from the runtime value.
        return JITFunction._key_of(self.value)

    def specialization_key(self):
        """Return the specialization component of the kernel cache key."""
        assert not self.param.do_not_specialize
        try:
            # Pointer-like values specialize only on 16-byte alignment.
            return (self.value.data_ptr() % JITFunction.divisibility == 0, )
        except AttributeError:
            pass

        if isinstance(self.value, int):
            # bool is a subclass of int, so booleans take this branch too.
            value = self.value
            return (
                value % JITFunction.divisibility == 0,
                value % JITFunction.divisibility_8 == 0,
                value == 1,
            )

        return (False, )
class KernelInterface(Generic[T]):
run: T
@@ -152,7 +203,6 @@ class KernelInterface(Generic[T]):
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
divisibility = 16
@@ -169,44 +219,44 @@ class JITFunction(KernelInterface[T]):
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
if -(2**31) <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
return "fp32"
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
@staticmethod
def _device_of(arg):
if hasattr(arg, "device"):
if hasattr(arg.device, 'type'):
return arg.device.type
return ''
try:
return arg.device.type
except AttributeError:
return ""
@staticmethod
def _pinned_memory_of(arg):
if hasattr(arg, "is_pinned"):
if isinstance(arg.is_pinned, Callable):
return arg.is_pinned()
return False
try:
return arg.is_pinned()
except (AttributeError, TypeError):
return False
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
return arg.data_ptr() % JITFunction.divisibility == 0
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
# TODO(jlebar): Fold this into the KernelArg class.
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
@@ -222,28 +272,38 @@ class JITFunction(KernelInterface[T]):
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
divisible_by_8 = {i for i, arg in enumerate(
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
divisible_by_16 = {
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_16(arg) and not param.do_not_specialize
}
divisible_by_8 = {
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_8(arg) and not param.do_not_specialize
}
equal_to_1 = {
i for i, arg in enumerate(args) if isinstance(
arg, int) and not isinstance(
arg, bool) and arg == 1 and i not in self.do_not_specialize}
param.num
for param, arg in zip(self.params, args)
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
}
# folded equal_to_1 and None
# TODO: method to collect all folded args
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
ids_of_folded_args = equal_to_1 | none_args
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
return namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
tuple(divisible_by_8))
# return _triton.code_gen.instance_descriptor(divisible_by_16,
# equal_to_1)
@staticmethod
def _type_of(key):
# None are nullptr -- implicitly converted to *i8
# `None` is nullptr. Implicitly convert to *i8.
if key is None:
return '*i8'
return "*i8"
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
@@ -281,21 +341,46 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
<<<<<<< HEAD
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs):
=======
def _call_hook(
self,
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
<<<<<<< HEAD
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
=======
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
pass
<<<<<<< HEAD
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
configs=configs)
@@ -326,18 +411,43 @@ class JITFunction(KernelInterface[T]):
return 'fp32'
else:
return self._key_of(arg)
=======
kwargs = dict(
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
)
return JITFunction.cache_hook(
key=key,
repr=repr,
fn=LegacyCompiler(module, name),
compile={"key": key, **kwargs},
is_manual_warmup=False,
already_compiled=False,
)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
device_types = [device_type for device_type in device_types if device_type != '']
device_types = [device_type for device_type in device_types if device_type != ""]
# Return cuda if one of the input tensors is cuda
if 'cuda' in device_types:
if "cuda" in device_types:
import torch
return 'hip' if torch.version.hip else 'cuda'
is_cpu = all(device_type == 'cpu' for device_type in device_types)
return "hip" if torch.version.hip else "cuda"
is_cpu = all(device_type == "cpu" for device_type in device_types)
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
# Return cuda if all the input tensors are cpu while the memory is pinned
if is_cpu and is_pinned_memory:
<<<<<<< HEAD
return 'cuda'
return device_types[0] if len(device_types) > 0 else 'cuda'
@@ -452,16 +562,193 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
scope = {"launcher_body": launcher_body}
exec(src, scope)
return scope[self.fn.__name__]
=======
return "cuda"
return device_types[0] if len(device_types) > 0 else "cuda"
    def run(self, *args, **kwargs):
        """Launch this kernel, compiling it first if it is not cached.

        Reserved keyword arguments (``grid``, ``num_warps``, ``num_stages``,
        ``stream``, ``warmup``, ...) are compiler/launch flags; they are
        stripped from ``kwargs`` before the remaining arguments are bound to
        the kernel signature. Returns the compiled kernel binary, or ``None``
        when a user-installed cache hook (``_call_hook``) aborts compilation.
        """
        from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps

        # Get a compiler-flags arg like `num_warps` and remove it from kwargs.
        def get_special_arg(name: str, default=None):
            if name not in kwargs:
                return default
            ret = kwargs[name]
            del kwargs[name]
            return ret

        grid = get_special_arg("grid")
        num_warps = get_special_arg("num_warps")
        num_ctas = get_special_arg("num_ctas", 1)
        num_stages = get_special_arg("num_stages")
        enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
        enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
        extern_libs = get_special_arg("extern_libs")
        stream = get_special_arg("stream")
        warmup = get_special_arg("warmup", False)
        device = get_special_arg("device")
        device_type = get_special_arg("device_type")

        # Bind the remaining arguments to `fn`.
        bound_args = self.signature.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # Pair every bound value with its KernelParam metadata, in order.
        assert len(bound_args.arguments) == len(self.params)
        args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]

        non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]

        # Three independent components of the cache key: argument types,
        # constexpr values, and value-specialization (alignment/==1) bits.
        sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
        spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
        constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)

        assert num_ctas > 0
        assert grid is not None
        if callable(grid):
            # Arguments are passed as a dict to `grid`, by contract.
            # TODO(jlebar): In the new launch API, pass the compiler flags as a
            # second parameter to `grid`.
            grid = grid(dict(bound_args.arguments))
        # Normalize the grid to three dimensions, padding with 1s.
        grid_size = len(grid)
        grid_0 = grid[0]
        grid_1 = grid[1] if grid_size > 1 else 1
        grid_2 = grid[2] if grid_size > 2 else 1

        if device_type is None:
            # Infer the device type from the non-constexpr tensor arguments.
            device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
            device_types = [_device_type for _device_type in device_types if _device_type != ""]
            device_type = self._conclude_device_type(device_types,
                                                     [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])

        # Only non-CUDA device types go through a pluggable backend object.
        device_backend = None
        if device_type not in ["cuda"]:
            device_backend = get_backend(device_type)
            if device_backend is None:
                raise ValueError("Cannot find backend for " + device_type)

        if device is None:
            if device_type in ["cuda"]:
                device = get_current_device()
                set_current_device(device)
            else:
                device = device_backend.get_current_device()
                device_backend.set_current_device(device)
        # Warmup never launches, so the stream is only needed otherwise.
        if stream is None and not warmup:
            if device_type in ["cuda"]:
                stream = get_cuda_stream(device)
            else:
                stream = device_backend.get_stream()

        # Fill in architecture-dependent defaults for unspecified flags.
        if num_warps is None:
            num_warps = get_arch_default_num_warps(device_type)
        if num_stages is None:
            num_stages = get_arch_default_num_stages(device_type)

        if device_type in ["cuda"]:
            version_key = get_cuda_version_key()
        else:
            version_key = device_backend.get_version_key()
        # The full cache key: toolchain version, argument keys, and every
        # compiler flag that affects code generation.
        key = (
            version_key,
            sig_key,
            constexpr_key,
            spec_key,
            num_warps,
            num_ctas,
            num_stages,
            enable_warp_specialization,
            enable_fp_fusion,
            self.debug,
        )
        if extern_libs is not None:
            key = (key, tuple(extern_libs.items()))

        # Kernel is not cached; we have to compile.
        if key not in self.cache[device]:
            configs = (self._get_config(*[arg.value for arg in args]), )
            # Constants folded into the binary: constexprs, args specialized
            # to 1, and None arguments.
            constants = {
                arg.param.num: arg.value
                for arg in args
                if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
            }
            for i, arg in constants.items():
                if callable(arg):
                    raise TypeError(f"Callable constexpr at index {i} is not supported")

            # Build kernel signature -- doesn't include constexpr arguments.
            signature = {
                arg.param.num: self._type_of(self._key_of(arg.value))
                for arg in args
                if not arg.param.is_constexpr
            }

            # Give the user-installed cache hook a chance to veto compilation.
            if self._call_hook(
                    key,
                    signature,
                    device,
                    constants,
                    num_warps,
                    num_ctas,
                    num_stages,
                    enable_warp_specialization,
                    enable_fp_fusion,
                    extern_libs,
                    configs,
            ):
                return None

            self.cache[device][key] = compile(
                self,
                signature=signature,
                device=device,
                constants=constants,
                num_warps=num_warps,
                num_ctas=num_ctas,
                num_stages=num_stages,
                enable_warp_specialization=enable_warp_specialization,
                enable_fp_fusion=enable_fp_fusion,
                extern_libs=extern_libs,
                configs=configs,
                debug=self.debug,
                device_type=device_type,
            )

        bin = self.cache[device][key]
        if not warmup:
            # NOTE(review): launch ABI — grid dims, warp/CTA/cluster config,
            # shared memory, stream, function handle, hooks, then the tensor
            # arguments (with TMA descriptors appended by
            # assemble_tensormap_to_arg).
            bin.c_wrapper(
                grid_0,
                grid_1,
                grid_2,
                bin.num_warps,
                bin.num_ctas,
                bin.clusterDims[0],
                bin.clusterDims[1],
                bin.clusterDims[2],
                bin.shared,
                stream,
                bin.cu_function,
                CompiledKernel.launch_enter_hook,
                CompiledKernel.launch_exit_hook,
                bin,
                *bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
            )
        return bin
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
do_not_specialize = do_not_specialize if do_not_specialize else []
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
self.signature = inspect.signature(fn)
self.do_not_specialize = do_not_specialize
self.params = []
for i, param in enumerate(self.signature.parameters.values()):
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
self.params.append(KernelParam(i, param, dns))
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
@@ -470,22 +757,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
# specialization hints
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# tma info
self.tensormaps_info = TMAInfos()
# launcher
self.run = self._make_launcher()
# TODO(jlebar): Remove uses of these fields outside this file, then
# remove the fields here.
self.arg_names = [p.name for p in self.params]
self.constexprs = [p.num for p in self.params if p.is_constexpr]
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
@@ -498,7 +781,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
self.hash = dependencies_finder.ret
return self.hash
def warmup(self, *args, **kwargs):
@@ -518,14 +801,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
if name == "src":
self.hash = None
def __repr__(self):
@@ -591,12 +870,14 @@ def jit(
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)
else:
return decorator
# -----------------------------------------------------------------------------
# Utilities for mocking tensors
# -----------------------------------------------------------------------------
@@ -607,10 +888,10 @@ class MockTensor:
Can be used in place of real tensors when calling:
kernel.warmup(MockTensor(torch.float32), ...)
"""
@staticmethod
def wrap_dtype(arg):
if arg.__class__.__name__ == "dtype" and\
arg.__module__ == "torch":
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
return MockTensor(arg)
return arg
@@ -623,6 +904,7 @@ class MockTensor:
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
@@ -637,7 +919,7 @@ class TensorWrapper:
return self.base.stride(i)
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
return f"TensorWrapper[{self.dtype}]({self.base})"
def element_size(self):
return self.base.element_size()
@@ -655,4 +937,4 @@ def reinterpret(tensor, dtype):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")