mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[RUNTIME] Make apis compatible with cuda 11 drivers (#2081)
https://github.com/openai/triton/issues/2042
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include "cuda.h"
|
||||
#include <dlfcn.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
@@ -338,6 +339,36 @@ static cuuint32_t *list_to_cuuint32_array(PyObject *listObj) {
|
||||
return array;
|
||||
}
|
||||
|
||||
typedef CUresult (*cuTensorMapEncodeTiled_t)(
|
||||
CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
|
||||
const cuuint64_t *globalStrides, const cuuint32_t *boxDim,
|
||||
const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
|
||||
CUtensorMapFloatOOBfill oobFill);
|
||||
|
||||
static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() {
|
||||
// Open the shared library
|
||||
void *handle = dlopen("libcuda.so", RTLD_LAZY);
|
||||
if (!handle) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so");
|
||||
return NULL;
|
||||
}
|
||||
// Clear any existing error
|
||||
dlerror();
|
||||
cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle =
|
||||
(cuTensorMapEncodeTiled_t)dlsym(handle, "cuTensorMapEncodeTiled");
|
||||
// Check for errors
|
||||
const char *dlsym_error = dlerror();
|
||||
if (dlsym_error) {
|
||||
PyErr_SetString(
|
||||
PyExc_RuntimeError,
|
||||
"Failed to retrieve cuTensorMapEncodeTiled from libcuda.so");
|
||||
return NULL;
|
||||
}
|
||||
return cuTensorMapEncodeTiledHandle;
|
||||
}
|
||||
|
||||
static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
|
||||
CUtensorMap *tensorMap = (CUtensorMap *)malloc(sizeof(CUtensorMap));
|
||||
CUtensorMapDataType tensorDataType;
|
||||
@@ -364,18 +395,21 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
|
||||
cuuint32_t *boxDim = list_to_cuuint32_array(boxDimObj);
|
||||
cuuint32_t *elementStrides = list_to_cuuint32_array(elementStridesObj);
|
||||
|
||||
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle = NULL;
|
||||
if (cuTensorMapEncodeTiledHandle == NULL) {
|
||||
cuTensorMapEncodeTiledHandle = getCuTensorMapEncodeTiledHandle();
|
||||
}
|
||||
// Call the function
|
||||
CUDA_CHECK(cuTensorMapEncodeTiled(tensorMap, tensorDataType, tensorRank,
|
||||
globalAddress, globalDim, globalStrides,
|
||||
boxDim, elementStrides, interleave, swizzle,
|
||||
l2Promotion, oobFill));
|
||||
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
|
||||
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
|
||||
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
|
||||
oobFill));
|
||||
|
||||
// Clean up
|
||||
free(globalDim);
|
||||
free(globalStrides);
|
||||
free(boxDim);
|
||||
free(elementStrides);
|
||||
|
||||
// Return the tensor map as a normal pointer
|
||||
return PyLong_FromUnsignedLongLong((unsigned long long)tensorMap);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user