[FRONTEND] drop the GIL around more CUDA ops (#2173)

This commit is contained in:
Shantanu
2023-08-24 20:31:38 -07:00
committed by GitHub
parent 22a2fe3e55
commit 7083dae4f2

View File

@@ -17,9 +17,7 @@ static inline void gpuAssert(CUresult code, const char *file, int line) {
#define CUDA_CHECK(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
if (PyErr_Occurred()) \
return NULL; \
{ gpuAssert((ans), __FILE__, __LINE__); } \
}
#define ADD_ENUM_ITEM(value) \
@@ -234,6 +232,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
int32_t n_spills = 0;
// create driver handles
CUcontext pctx = 0;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
@@ -264,6 +264,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
Py_END_ALLOW_THREADS;
if (PyErr_Occurred()) {
return NULL;
@@ -281,7 +282,9 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
return NULL; // Error parsing arguments
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
Py_END_ALLOW_THREADS;
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
}
@@ -300,7 +303,9 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
dstDevice = (CUdeviceptr)dstDevicePtr;
srcHost = (const void *)srcHostPtr;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
}
@@ -312,7 +317,9 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
return NULL; // Error parsing arguments
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemFree(dptr));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
}
@@ -400,10 +407,12 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
cuTensorMapEncodeTiledHandle = getCuTensorMapEncodeTiledHandle();
}
// Call the function
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
oobFill));
Py_END_ALLOW_THREADS;
// Clean up
free(globalDim);