mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '36fc54b6f28168d3644808bfe299f1ba06a36272' into ifu230908-2
Conflicts: .gitignore bin/triton-translate.cpp include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/triton_to_tritongpu.mlir test/Conversion/tritongpu_to_llvm.mlir test/TritonGPU/coalesce.mlir unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
This commit is contained in:
@@ -33,7 +33,7 @@ class Autotuner(KernelInterface):
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
if not configs:
|
||||
self.configs = [Config({}, num_warps=4, num_stages=2)]
|
||||
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
@@ -79,7 +79,11 @@ class Autotuner(KernelInterface):
|
||||
if config.pre_hook:
|
||||
config.pre_hook(full_nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current)
|
||||
try:
|
||||
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
|
||||
except OutOfResources:
|
||||
@@ -125,12 +129,12 @@ class Autotuner(KernelInterface):
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
if config.pre_hook is not None:
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
config.pre_hook(full_nargs)
|
||||
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
self.nargs = None
|
||||
return ret
|
||||
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
|
||||
|
||||
def prune_configs(self, kwargs):
|
||||
pruned_configs = self.configs
|
||||
@@ -143,10 +147,16 @@ class Autotuner(KernelInterface):
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||
num_warps=config.num_warps)
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent)
|
||||
for config in pruned_configs
|
||||
}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
pruned_configs = sorted(
|
||||
est_timing.keys(),
|
||||
key=lambda x: est_timing[x])[
|
||||
:top_k]
|
||||
return pruned_configs
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -155,7 +165,10 @@ class Autotuner(KernelInterface):
|
||||
self.fn.warmup(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
num_stages=config.num_stages,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
@@ -174,15 +187,20 @@ class Config:
|
||||
:type num_warps: int
|
||||
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
||||
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
||||
:type num_stages: int
|
||||
:type enable_warp_specialization: bool
|
||||
:ivar enable_warp_specialization: enable specialization (spatial partitioning) or not. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#spatial-partitioning-also-known-as-warp-specialization
|
||||
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
|
||||
function are args.
|
||||
"""
|
||||
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, enable_warp_specialization=False, pre_hook=None):
|
||||
self.kwargs = kwargs
|
||||
self.num_warps = num_warps
|
||||
self.num_ctas = num_ctas
|
||||
self.num_stages = num_stages
|
||||
self.enable_warp_specialization = enable_warp_specialization
|
||||
# TODO[shuhaoj]: May make enable_persistent configurable in the future if necessary.
|
||||
self.enable_persistent = False
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
|
||||
@@ -190,7 +208,11 @@ class Config:
|
||||
for k, v in self.kwargs.items():
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
res.append(f'num_ctas: {self.num_ctas}')
|
||||
res.append(f'num_stages: {self.num_stages}')
|
||||
res.append(
|
||||
f'enable_warp_specialization: {self.enable_warp_specialization}')
|
||||
res.append(f'enable_persistent: {self.enable_persistent}')
|
||||
return ', '.join(res)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "cuda.h"
|
||||
#include <dlfcn.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
@@ -16,11 +17,172 @@ static inline void gpuAssert(CUresult code, const char *file, int line) {
|
||||
|
||||
#define CUDA_CHECK(ans) \
|
||||
{ \
|
||||
gpuAssert((ans), __FILE__, __LINE__); \
|
||||
if (PyErr_Occurred()) \
|
||||
return NULL; \
|
||||
{ gpuAssert((ans), __FILE__, __LINE__); } \
|
||||
}
|
||||
|
||||
// Insert the integer constant `value` into the dict variable `enum_dict`
// (which must be in scope at the expansion site), keyed by the constant's
// source spelling.
//
// PyDict_SetItemString does not steal a reference, so the temporary int is
// released here. The insertion is skipped if either object failed to
// allocate, so a failed PyDict_New()/PyLong_FromLong() cannot be dereferenced.
#define ADD_ENUM_ITEM(value)                                                   \
  do {                                                                         \
    PyObject *py_value = PyLong_FromLong(value);                               \
    if (enum_dict != NULL && py_value != NULL) {                               \
      PyDict_SetItemString(enum_dict, #value, py_value);                       \
    }                                                                          \
    Py_XDECREF(py_value);                                                      \
  } while (0)
|
||||
|
||||
#define ADD_ENUM_ITEM_0()
|
||||
#define ADD_ENUM_ITEM_1(v1) ADD_ENUM_ITEM(v1)
|
||||
#define ADD_ENUM_ITEM_2(v1, v2) \
|
||||
ADD_ENUM_ITEM(v1); \
|
||||
ADD_ENUM_ITEM(v2);
|
||||
#define ADD_ENUM_ITEM_3(v1, v2, v3) \
|
||||
ADD_ENUM_ITEM(v1); \
|
||||
ADD_ENUM_ITEM(v2); \
|
||||
ADD_ENUM_ITEM(v3);
|
||||
#define ADD_ENUM_ITEM_4(v1, v2, v3, v4) \
|
||||
ADD_ENUM_ITEM(v1); \
|
||||
ADD_ENUM_ITEM(v2); \
|
||||
ADD_ENUM_ITEM(v3); \
|
||||
ADD_ENUM_ITEM(v4);
|
||||
#define ADD_ENUM_ITEM_5(v1, v2, v3, v4, v5) \
|
||||
ADD_ENUM_ITEM_2(v1, v2); \
|
||||
ADD_ENUM_ITEM_3(v3, v4, v5);
|
||||
#define ADD_ENUM_ITEM_6(v1, v2, v3, v4, v5, v6) \
|
||||
ADD_ENUM_ITEM_2(v1, v2); \
|
||||
ADD_ENUM_ITEM_4(v3, v4, v5, v6);
|
||||
#define ADD_ENUM_ITEM_7(v1, v2, v3, v4, v5, v6, v7) \
|
||||
ADD_ENUM_ITEM_3(v1, v2, v3); \
|
||||
ADD_ENUM_ITEM_4(v4, v5, v6, v7);
|
||||
#define ADD_ENUM_ITEM_8(v1, v2, v3, v4, v5, v6, v7, v8) \
|
||||
ADD_ENUM_ITEM_4(v1, v2, v3, v4); \
|
||||
ADD_ENUM_ITEM_4(v5, v6, v7, v8);
|
||||
#define ADD_ENUM_ITEM_9(v1, v2, v3, v4, v5, v6, v7, v8, v9) \
|
||||
ADD_ENUM_ITEM_5(v1, v2, v3, v4, v5); \
|
||||
ADD_ENUM_ITEM_4(v6, v7, v8, v9);
|
||||
#define ADD_ENUM_ITEM_10(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) \
|
||||
ADD_ENUM_ITEM_5(v1, v2, v3, v4, v5); \
|
||||
ADD_ENUM_ITEM_5(v6, v7, v8, v9, v10);
|
||||
#define ADD_ENUM_ITEM_11(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) \
|
||||
ADD_ENUM_ITEM_6(v1, v2, v3, v4, v5, v6); \
|
||||
ADD_ENUM_ITEM_5(v7, v8, v9, v10, v11);
|
||||
#define ADD_ENUM_ITEM_12(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) \
|
||||
ADD_ENUM_ITEM_6(v1, v2, v3, v4, v5, v6); \
|
||||
ADD_ENUM_ITEM_6(v7, v8, v9, v10, v11, v12);
|
||||
#define ADD_ENUM_ITEM_13(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, \
|
||||
v13) \
|
||||
ADD_ENUM_ITEM_7(v1, v2, v3, v4, v5, v6, v7); \
|
||||
ADD_ENUM_ITEM_6(v8, v9, v10, v11, v12, v13);
|
||||
#define ADD_ENUM_ITEM_14(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, \
|
||||
v13, v14) \
|
||||
ADD_ENUM_ITEM_7(v1, v2, v3, v4, v5, v6, v7); \
|
||||
ADD_ENUM_ITEM_7(v8, v9, v10, v11, v12, v13, v14);
|
||||
|
||||
#define DISPATCH_ARGS_N(_14, _13, _12, _11, _10, _9, _8, _7, _6, _5, _4, _3, \
|
||||
_2, _1, N, ...) \
|
||||
ADD_ENUM_ITEM_##N
|
||||
#define DISPATCH_ARGS(...) \
|
||||
DISPATCH_ARGS_N(__VA_ARGS__, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \
|
||||
0) \
|
||||
(__VA_ARGS__)
|
||||
|
||||
// Build a dict mapping each listed enumerator name to its integer value and
// attach it to `module` as the attribute `enum_name`.
//
// The items are only inserted once the dict is known to exist.
// PyObject_SetAttrString takes its own reference, so the local reference is
// dropped afterwards. The extra ';' after DISPATCH_ARGS keeps the expansion
// well-formed for a single-item list, whose expansion has no terminator.
#define ADD_ENUM_TO_MODULE(module, enum_name, ...)                             \
  do {                                                                         \
    PyObject *enum_dict = PyDict_New();                                        \
    if (enum_dict != NULL) {                                                   \
      DISPATCH_ARGS(__VA_ARGS__);                                              \
      PyObject_SetAttrString(module, #enum_name, enum_dict);                   \
      Py_DECREF(enum_dict);                                                    \
    }                                                                          \
  } while (0)
|
||||
|
||||
static void defineEnums(PyObject *self) {
|
||||
ADD_ENUM_TO_MODULE(
|
||||
self, CUtensorMapDataType, CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
CU_TENSOR_MAP_DATA_TYPE_UINT16, CU_TENSOR_MAP_DATA_TYPE_UINT32,
|
||||
CU_TENSOR_MAP_DATA_TYPE_INT32, CU_TENSOR_MAP_DATA_TYPE_UINT64,
|
||||
CU_TENSOR_MAP_DATA_TYPE_INT64, CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
|
||||
CU_TENSOR_MAP_DATA_TYPE_FLOAT32, CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
|
||||
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
|
||||
CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ);
|
||||
|
||||
ADD_ENUM_TO_MODULE(self, CUtensorMapInterleave, CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
CU_TENSOR_MAP_INTERLEAVE_16B,
|
||||
CU_TENSOR_MAP_INTERLEAVE_32B);
|
||||
|
||||
ADD_ENUM_TO_MODULE(self, CUtensorMapSwizzle, CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
CU_TENSOR_MAP_SWIZZLE_32B, CU_TENSOR_MAP_SWIZZLE_64B,
|
||||
CU_TENSOR_MAP_SWIZZLE_128B);
|
||||
|
||||
ADD_ENUM_TO_MODULE(
|
||||
self, CUtensorMapL2promotion, CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_L2_64B, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_L2_256B);
|
||||
|
||||
ADD_ENUM_TO_MODULE(self, CUtensorMapFloatOOBfill,
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD cuuint32_t value;
|
||||
} PyCUuint32;
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD cuuint64_t value;
|
||||
} PyCUuint64;
|
||||
|
||||
#define DEFINE_CUUINT_CONSTRUCTOR(NAME, TYPE, FORMAT, VALUE_TYPE) \
|
||||
static PyObject *Py##NAME##_New(PyTypeObject *type, PyObject *args, \
|
||||
PyObject *kwds) { \
|
||||
Py##NAME *self; \
|
||||
VALUE_TYPE value; \
|
||||
if (!PyArg_ParseTuple(args, FORMAT, &value)) \
|
||||
return NULL; \
|
||||
self = (Py##NAME *)type->tp_alloc(type, 0); \
|
||||
if (self != NULL) { \
|
||||
self->value = (TYPE)value; \
|
||||
} \
|
||||
return (PyObject *)self; \
|
||||
}
|
||||
|
||||
DEFINE_CUUINT_CONSTRUCTOR(CUuint32, cuuint32_t, "l", long)
|
||||
DEFINE_CUUINT_CONSTRUCTOR(CUuint64, cuuint64_t, "L", long long)
|
||||
|
||||
static PyTypeObject PyCUuint32_Type = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0).tp_name = "cuda_utils.cuuint32_t",
|
||||
.tp_basicsize = sizeof(PyCUuint32),
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
.tp_new = PyCUuint32_New,
|
||||
};
|
||||
|
||||
static PyTypeObject PyCUuint64_Type = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0).tp_name = "cuda_utils.cuuint64_t",
|
||||
.tp_basicsize = sizeof(PyCUuint64),
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
.tp_new = PyCUuint64_New,
|
||||
};
|
||||
|
||||
static void defineTypes(PyObject *self) {
|
||||
if (PyType_Ready(&PyCUuint32_Type) < 0) {
|
||||
PyErr_SetString(PyExc_TypeError, "Failed to ready cuuint32_t type");
|
||||
return;
|
||||
}
|
||||
Py_INCREF(&PyCUuint32_Type);
|
||||
if (PyModule_AddObject(self, "cuuint32_t", (PyObject *)&PyCUuint32_Type) <
|
||||
0) {
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
"Failed to add cuuint32_t type to module");
|
||||
return;
|
||||
}
|
||||
|
||||
if (PyType_Ready(&PyCUuint64_Type) < 0) {
|
||||
PyErr_SetString(PyExc_TypeError, "Failed to ready cuuint64_t type");
|
||||
return;
|
||||
}
|
||||
Py_INCREF(&PyCUuint64_Type);
|
||||
if (PyModule_AddObject(self, "cuuint64_t", (PyObject *)&PyCUuint64_Type) <
|
||||
0) {
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
"Failed to add cuuint64_t type to module");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
||||
int device_id;
|
||||
if (!PyArg_ParseTuple(args, "i", &device_id))
|
||||
@@ -70,6 +232,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
||||
int32_t n_spills = 0;
|
||||
// create driver handles
|
||||
CUcontext pctx = 0;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
@@ -100,6 +264,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
||||
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static));
|
||||
}
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
if (PyErr_Occurred()) {
|
||||
return NULL;
|
||||
@@ -108,11 +273,165 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
||||
n_spills);
|
||||
}
|
||||
|
||||
// cuMemAlloc(bytesize) -> int
//
// Allocates `bytesize` bytes with cuMemAlloc and returns the device pointer
// as a Python int. Returns NULL with an exception set if argument parsing
// fails or an error is pending after the driver call.
static PyObject *memAlloc(PyObject *self, PyObject *args) {
  // "K" stores an unsigned long long, so parse into exactly that type and
  // narrow afterwards instead of writing through a size_t pointer.
  unsigned long long bytesize = 0;
  CUdeviceptr dptr = 0;

  if (!PyArg_ParseTuple(args, "K", &bytesize)) {
    return NULL; // Error parsing arguments
  }

  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK(cuMemAlloc(&dptr, (size_t)bytesize));
  Py_END_ALLOW_THREADS;

  // Same post-check loadBinary performs after its CUDA_CHECK calls: never
  // hand back a pointer value while an exception is pending.
  if (PyErr_Occurred()) {
    return NULL;
  }

  return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
}
|
||||
|
||||
// cuMemcpyHtoD(dst_device_ptr, src_host_ptr, byte_count) -> None
//
// Copies `byte_count` bytes from the host address to the device address with
// cuMemcpyHtoD. Both addresses arrive as plain integers and are used as-is.
// Returns NULL with an exception set if argument parsing fails or an error
// is pending after the driver call.
static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
  // All three values are parsed with "K", which stores unsigned long long.
  unsigned long long dstDevicePtr = 0, srcHostPtr = 0, byteCount = 0;

  if (!PyArg_ParseTuple(args, "KKK", &dstDevicePtr, &srcHostPtr, &byteCount)) {
    return NULL; // Error parsing arguments
  }

  CUdeviceptr dstDevice = (CUdeviceptr)dstDevicePtr;
  const void *srcHost = (const void *)srcHostPtr;

  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, (size_t)byteCount));
  Py_END_ALLOW_THREADS;

  // Same post-check loadBinary performs after its CUDA_CHECK calls:
  // returning None with an exception pending would be an API violation.
  if (PyErr_Occurred()) {
    return NULL;
  }

  Py_RETURN_NONE;
}
|
||||
|
||||
// cuMemFree(device_ptr) -> None
//
// Releases the device allocation at `device_ptr` with cuMemFree. Returns
// NULL with an exception set if argument parsing fails or an error is
// pending after the driver call.
static PyObject *memFree(PyObject *self, PyObject *args) {
  // "K" stores an unsigned long long; convert to CUdeviceptr explicitly
  // rather than assuming the two types have the same width.
  unsigned long long rawPtr = 0;

  if (!PyArg_ParseTuple(args, "K", &rawPtr)) {
    return NULL; // Error parsing arguments
  }

  CUdeviceptr dptr = (CUdeviceptr)rawPtr;

  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK(cuMemFree(dptr));
  Py_END_ALLOW_THREADS;

  // Same post-check loadBinary performs after its CUDA_CHECK calls.
  if (PyErr_Occurred()) {
    return NULL;
  }

  Py_RETURN_NONE;
}
|
||||
|
||||
// Helper function to convert a Python list to a cuuint64_t array
|
||||
static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) {
|
||||
Py_ssize_t len = PyList_Size(listObj);
|
||||
cuuint64_t *array = malloc(len * sizeof(cuuint64_t));
|
||||
for (Py_ssize_t i = 0; i < len; i++) {
|
||||
PyObject *item = PyList_GetItem(listObj, i);
|
||||
array[i] = (cuuint64_t)PyLong_AsUnsignedLongLong(item);
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
// Helper function to convert a Python list to a cuuint32_t array
|
||||
static cuuint32_t *list_to_cuuint32_array(PyObject *listObj) {
|
||||
Py_ssize_t len = PyList_Size(listObj);
|
||||
cuuint32_t *array = malloc(len * sizeof(cuuint32_t));
|
||||
for (Py_ssize_t i = 0; i < len; i++) {
|
||||
PyObject *item = PyList_GetItem(listObj, i);
|
||||
array[i] = (cuuint32_t)PyLong_AsUnsignedLong(item);
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
typedef CUresult (*cuTensorMapEncodeTiled_t)(
|
||||
CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
|
||||
const cuuint64_t *globalStrides, const cuuint32_t *boxDim,
|
||||
const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
|
||||
CUtensorMapFloatOOBfill oobFill);
|
||||
|
||||
static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() {
|
||||
// Open the shared library
|
||||
void *handle = dlopen("libcuda.so", RTLD_LAZY);
|
||||
if (!handle) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so");
|
||||
return NULL;
|
||||
}
|
||||
// Clear any existing error
|
||||
dlerror();
|
||||
cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle =
|
||||
(cuTensorMapEncodeTiled_t)dlsym(handle, "cuTensorMapEncodeTiled");
|
||||
// Check for errors
|
||||
const char *dlsym_error = dlerror();
|
||||
if (dlsym_error) {
|
||||
PyErr_SetString(
|
||||
PyExc_RuntimeError,
|
||||
"Failed to retrieve cuTensorMapEncodeTiled from libcuda.so");
|
||||
return NULL;
|
||||
}
|
||||
return cuTensorMapEncodeTiledHandle;
|
||||
}
|
||||
|
||||
// cuTensorMapEncodeTiled(data_type, rank, global_address, global_dim,
//                        global_strides, box_dim, element_strides,
//                        interleave, swizzle, l2_promotion, oob_fill) -> int
//
// Encodes a tiled tensor map into a heap-allocated CUtensorMap and returns
// its host address as a Python int. The four list arguments are converted to
// temporary C arrays that are released before returning. The CUtensorMap
// itself is not freed here.
// NOTE(review): presumably the Python caller owns the returned buffer -
// confirm where (or whether) it is released.
//
// Returns NULL with an exception set when argument parsing or list
// conversion fails, the driver symbol cannot be resolved, memory is
// exhausted, or an error is pending after the driver call.
static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
  // Parse into the exact C types the format codes write ("i" -> int,
  // "K" -> unsigned long long) and convert to the CUDA types afterwards.
  int tensorDataType, tensorRank;
  int interleave, swizzle, l2Promotion, oobFill;
  unsigned long long globalAddress;
  PyObject *globalDimObj, *globalStridesObj, *boxDimObj, *elementStridesObj;

  CUtensorMap *tensorMap = NULL;
  cuuint64_t *globalDim = NULL;
  cuuint64_t *globalStrides = NULL;
  cuuint32_t *boxDim = NULL;
  cuuint32_t *elementStrides = NULL;
  PyObject *ret = NULL;

  // Resolved once and cached for subsequent calls.
  static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle = NULL;

  // Parse arguments
  if (!PyArg_ParseTuple(args, "iiKO!O!O!O!iiii", &tensorDataType, &tensorRank,
                        &globalAddress, &PyList_Type, &globalDimObj,
                        &PyList_Type, &globalStridesObj, &PyList_Type,
                        &boxDimObj, &PyList_Type, &elementStridesObj,
                        &interleave, &swizzle, &l2Promotion, &oobFill)) {
    return NULL; // Error parsing arguments
  }

  if (cuTensorMapEncodeTiledHandle == NULL) {
    cuTensorMapEncodeTiledHandle = getCuTensorMapEncodeTiledHandle();
    if (cuTensorMapEncodeTiledHandle == NULL) {
      return NULL; // lookup failed; never call through a NULL pointer
    }
  }

  // Convert Python lists to C arrays, stopping at the first failure.
  globalDim = list_to_cuuint64_array(globalDimObj);
  if (!PyErr_Occurred()) {
    globalStrides = list_to_cuuint64_array(globalStridesObj);
  }
  if (!PyErr_Occurred()) {
    boxDim = list_to_cuuint32_array(boxDimObj);
  }
  if (!PyErr_Occurred()) {
    elementStrides = list_to_cuuint32_array(elementStridesObj);
  }
  if (PyErr_Occurred()) {
    goto cleanup;
  }

  // Allocate the output only after every input has been validated, so the
  // early-exit paths above have nothing extra to release.
  tensorMap = (CUtensorMap *)malloc(sizeof(CUtensorMap));
  if (tensorMap == NULL) {
    PyErr_NoMemory();
    goto cleanup;
  }

  // Call the function
  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK(cuTensorMapEncodeTiledHandle(
      tensorMap, (CUtensorMapDataType)tensorDataType, (cuuint32_t)tensorRank,
      (void *)(uintptr_t)globalAddress, globalDim, globalStrides, boxDim,
      elementStrides, (CUtensorMapInterleave)interleave,
      (CUtensorMapSwizzle)swizzle, (CUtensorMapL2promotion)l2Promotion,
      (CUtensorMapFloatOOBfill)oobFill));
  Py_END_ALLOW_THREADS;

  // Same post-check loadBinary performs after its CUDA_CHECK calls: do not
  // return a tensor map while an exception is pending.
  if (PyErr_Occurred()) {
    free(tensorMap);
    goto cleanup;
  }

  // Return the tensor map as a normal pointer
  ret = PyLong_FromUnsignedLongLong((unsigned long long)(uintptr_t)tensorMap);
  if (ret == NULL) {
    free(tensorMap);
  }

cleanup:
  // Clean up the temporary argument arrays (free(NULL) is a no-op)
  free(globalDim);
  free(globalStrides);
  free(boxDim);
  free(elementStrides);
  return ret;
}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS,
|
||||
"Load provided cubin into CUDA driver"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
||||
"Get the properties for a given device"},
|
||||
{"cuMemAlloc", memAlloc, METH_VARARGS},
|
||||
{"cuMemcpyHtoD", memcpyHtoD, METH_VARARGS},
|
||||
{"cuMemFree", memFree, METH_VARARGS},
|
||||
{"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
@@ -126,6 +445,10 @@ PyMODINIT_FUNC PyInit_cuda_utils(void) {
|
||||
if (m == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
defineEnums(m);
|
||||
defineTypes(m);
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
|
||||
return m;
|
||||
}
|
||||
|
||||
@@ -40,18 +40,20 @@ class FileCacheManager(CacheManager):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
else:
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
|
||||
def _make_path(self, filename) -> str:
|
||||
return os.path.join(self.cache_dir, filename)
|
||||
|
||||
def has_file(self, filename):
|
||||
def has_file(self, filename) -> bool:
|
||||
if not self.cache_dir:
|
||||
return False
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
return os.path.exists(self._make_path(filename))
|
||||
|
||||
def get_file(self, filename) -> Optional[str]:
|
||||
@@ -80,16 +82,16 @@ class FileCacheManager(CacheManager):
|
||||
return result
|
||||
|
||||
# Note a group of pushed files as being part of a group
|
||||
def put_group(self, filename: str, group: Dict[str, str]):
|
||||
def put_group(self, filename: str, group: Dict[str, str]) -> str:
|
||||
if not self.cache_dir:
|
||||
return
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
|
||||
grp_filename = f"__grp__{filename}"
|
||||
return self.put(grp_contents, grp_filename, binary=False)
|
||||
|
||||
def put(self, data, filename, binary=True) -> str:
|
||||
if not self.cache_dir:
|
||||
return
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
binary = isinstance(data, bytes)
|
||||
if not binary:
|
||||
data = str(data)
|
||||
|
||||
@@ -52,6 +52,15 @@ class CudaUtils(object):
|
||||
spec.loader.exec_module(mod)
|
||||
self.load_binary = mod.load_binary
|
||||
self.get_device_properties = mod.get_device_properties
|
||||
self.CUtensorMapDataType = mod.CUtensorMapDataType
|
||||
self.CUtensorMapInterleave = mod.CUtensorMapInterleave
|
||||
self.CUtensorMapSwizzle = mod.CUtensorMapSwizzle
|
||||
self.CUtensorMapL2promotion = mod.CUtensorMapL2promotion
|
||||
self.CUtensorMapFloatOOBfill = mod.CUtensorMapFloatOOBfill
|
||||
self.cuTensorMapEncodeTiled = mod.cuTensorMapEncodeTiled
|
||||
self.cuMemAlloc = mod.cuMemAlloc
|
||||
self.cuMemcpyHtoD = mod.cuMemcpyHtoD
|
||||
self.cuMemFree = mod.cuMemFree
|
||||
|
||||
|
||||
class CudaDriver(DriverBase):
|
||||
|
||||
@@ -11,7 +11,9 @@ from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..language.core import dtype
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
@@ -59,7 +61,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, globals, src) -> None:
|
||||
super().__init__()
|
||||
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
|
||||
self.ret = hashlib.sha1(src.encode("utf-8")).hexdigest()
|
||||
self.globals = globals
|
||||
|
||||
def visit_Name(self, node):
|
||||
@@ -89,7 +91,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
func.hash = finder.ret
|
||||
noinline = str(getattr(func, 'noinline', False))
|
||||
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
|
||||
self.ret = hashlib.md5(self.ret).hexdigest()
|
||||
self.ret = hashlib.sha1(self.ret).hexdigest()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
@@ -102,23 +104,29 @@ def version_key():
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# backend
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
while True:
|
||||
chunk = f.read(1024 ** 2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
contents.append(libtriton_hash.hexdigest())
|
||||
# language
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# ptxas version
|
||||
ptxas = path_to_ptxas()[0]
|
||||
ptxas_version = hashlib.md5(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
@@ -147,6 +155,11 @@ class JITFunction(KernelInterface[T]):
|
||||
# Hook for inspecting compiled functions and modules
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
# Hopper TMA load and store primitives require the tensor stride to be 16-byte aligned.
|
||||
# And we only support WGMMA with float16 dtype on Hopper for now.
|
||||
# So whether a LoadOp or StoreOp is lowered to a TMA copy depends on whether the tensor stride is divisible by 8.
|
||||
# TODO: Make it more reasonable to handle multiple dtypes.
|
||||
divisibility_8 = 8
|
||||
|
||||
@staticmethod
|
||||
def _key_of(arg):
|
||||
@@ -201,10 +214,29 @@ class JITFunction(KernelInterface[T]):
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
|
||||
|
||||
def is_divisible_by_8(x):
|
||||
if isinstance(x, int):
|
||||
return x % JITFunction.divisibility_8 == 0
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
divisible_by_8 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {
|
||||
i for i, arg in enumerate(args) if isinstance(
|
||||
arg, int) and not isinstance(
|
||||
arg, bool) and arg == 1 and i not in self.do_not_specialize}
|
||||
# folded equal_to_1 and None
|
||||
# TODO: method to collect all folded args
|
||||
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
|
||||
ids_of_folded_args = equal_to_1 | none_args
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16,
|
||||
# equal_to_1)
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
@@ -214,9 +246,10 @@ class JITFunction(KernelInterface[T]):
|
||||
dtype_str = str(key).split(".")[-1]
|
||||
tys = {
|
||||
"bool": "i1",
|
||||
"float8e4": "fp8e4",
|
||||
"float8e4nv": "fp8e4nv",
|
||||
"float8e5": "fp8e5",
|
||||
"float8e4b15": "fp8e4b15",
|
||||
"float8e4b15x4": "fp8e4b15x4",
|
||||
"float16": "fp16",
|
||||
"bfloat16": "bf16",
|
||||
"float32": "fp32",
|
||||
@@ -243,13 +276,13 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -259,21 +292,22 @@ class JITFunction(KernelInterface[T]):
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
|
||||
def _get_arg_specialization_key(self, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg, '')
|
||||
if arg_annotation == '':
|
||||
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
|
||||
else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \
|
||||
else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \
|
||||
else (False,)'
|
||||
elif 'Tensor' in arg_annotation:
|
||||
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
|
||||
elif arg_annotation == 'int':
|
||||
return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)'
|
||||
return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)'
|
||||
else:
|
||||
return '(False,)'
|
||||
|
||||
@@ -304,8 +338,11 @@ class JITFunction(KernelInterface[T]):
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
|
||||
regular_args = [f'{arg}' for i, arg in enumerate(
|
||||
self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [
|
||||
f'{arg}' for i, arg in enumerate(
|
||||
self.arg_names) if i in self.constexprs]
|
||||
args = ', '.join(regular_args)
|
||||
# cache key for regular argument type
|
||||
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
|
||||
@@ -322,19 +359,24 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
|
||||
src = f"""
|
||||
<<<<<<< HEAD
|
||||
|
||||
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
=======
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
grid = grid({{{grid_args}}})
|
||||
@@ -366,16 +408,29 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
bin = cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, {args})
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
|
||||
@@ -386,10 +441,12 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
@@ -418,9 +475,6 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
# specialization hints
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
@@ -437,6 +491,12 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
|
||||
# specialization hints
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
# tma info
|
||||
self.tensormaps_info = TMAInfos()
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
# re-use docs of wrapped function
|
||||
@@ -594,6 +654,9 @@ class TensorWrapper:
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
|
||||
def element_size(self):
|
||||
return self.base.element_size()
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
if isinstance(tensor, TensorWrapper):
|
||||
|
||||
Reference in New Issue
Block a user