Fix segfault in assertion test. (#2520)

<git-pr-chain>

#### Commits in this PR
1. Fix segfault in assertion test.
    
The issue here is that we were not checking the return values of the
CUDA API
calls we were making. We call one function and then use the data it
returns as
input to another call. Obviously this doesn't work if the first call
returns
    an error and doesn't actually return meaningful data.
    
I don't know why this was passing in CI, but it failed consistently for
me.

#### [PR chain](https://github.com/jlebar/git-pr-chain)
1. 👉 #2520 👈 **YOU ARE HERE**


</git-pr-chain>
This commit is contained in:
Justin Lebar
2023-10-19 13:42:38 -07:00
committed by GitHub
parent bdf464e4a8
commit 30186f401e
3 changed files with 86 additions and 47 deletions

View File

@@ -1,27 +1,42 @@
#include "cuda.h"
#include <dlfcn.h>
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line) {
if (code != CUDA_SUCCESS) {
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
static bool gpuAssert(CUresult code, const char *file, int line) {
if (code == CUDA_SUCCESS)
return true;
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
return false;
}
#define CUDA_CHECK(ans) \
{ \
{ gpuAssert((ans), __FILE__, __LINE__); } \
}
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) \
return NULL; \
} while (0)
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
PyEval_RestoreThread(_save); \
return NULL; \
} \
} while (0)
#define ADD_ENUM_ITEM(value) \
do { \
@@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
@@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
CUcontext pctx = 0;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
}
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
@@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize));
Py_END_ALLOW_THREADS;
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
@@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
srcHost = (const void *)srcHostPtr;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuMemcpyHtoD(dstDevice, srcHost, byteCount));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemFree(dptr));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
}
// Call the function
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle(
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
oobFill));