I noticed that Triton uses the `ptxas` version as part of the version hash even for non-CUDA targets. This is an attempt to fix that. Moving the version-key calculation into the backend makes sense to me from an architectural standpoint, so that's the approach I've taken here. I'm less confident in the implementation itself, so please let me know if you have any feedback.
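As a rough sketch of the idea (the class names, method names, and toolchain paths below are hypothetical assumptions, not Triton's actual API): each backend computes its own version key, so only the CUDA backend folds `ptxas` into the hash, and a ROCm build keys on its own toolchain instead.

```python
# Hypothetical sketch of per-backend version keys; names and paths are
# illustrative assumptions, not Triton's real API.
import hashlib


def _hash_file(path: str) -> str:
    # Hash a toolchain binary so cache keys change whenever it changes.
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


class CUDABackend:
    ptxas_path = "/usr/local/cuda/bin/ptxas"  # assumed location

    def get_version_key(self) -> str:
        # Only the CUDA backend folds ptxas into its key.
        return _hash_file(self.ptxas_path)


class ROCmBackend:
    lld_path = "/opt/rocm/llvm/bin/ld.lld"  # assumed location

    def get_version_key(self) -> str:
        # A ROCm build keys on its own toolchain and never touches ptxas.
        return _hash_file(self.lld_path)


def get_version_key(backend) -> str:
    # Callers ask the active backend instead of hard-coding ptxas.
    return backend.get_version_key()
```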
import hashlib
import os
import tempfile

from ..common import _build
from ..common.backend import get_cuda_version_key
from ..common.build import is_hip
from ..runtime.cache import get_cache_manager
from .utils import generate_cu_signature

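# Builds, caches, and loads the per-kernel C "stub" extension module that
# unpacks Python arguments and launches the compiled kernel.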
# ----- stub --------


def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    # Get unique key for the compiled code
    signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
    key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}"
    for kw in kwargs:
        key = f"{key}-{kwargs.get(kw)}"
    key = hashlib.md5(key.encode("utf-8")).hexdigest()
    return key
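# Example: a signature {0: '*fp32', 1: 'i32'} contributes "ptri32" to the
# pre-hash string, so the cached stub is invalidated when argument types,
# constants, specialization ids, or the toolchain version change.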


def make_stub(name, signature, constants, ids, **kwargs):
    # name of files that are cached
    so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
    so_cache_manager = get_cache_manager(so_cache_key)
    so_name = f"{name}.so"
    # retrieve stub from cache if it exists
    cache_path = so_cache_manager.get_file(so_name)
    if cache_path is None:
        with tempfile.TemporaryDirectory() as tmpdir:
            src = generate_launcher(constants, signature, ids)
            src_path = os.path.join(tmpdir, "main.c")
            with open(src_path, "w") as f:
                f.write(src)
            so = _build(name, src_path, tmpdir)
            with open(so, "rb") as f:
                return so_cache_manager.put(f.read(), so_name, binary=True)
    else:
        return cache_path


# ----- source code generation --------


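# Map a Triton type string to the C type used for launcher arguments;
# pointers become hipDeviceptr_t on HIP targets and CUdeviceptr on CUDA.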
def ty_to_cpp(ty):
    if ty[0] == '*':
        return "hipDeviceptr_t" if is_hip() else "CUdeviceptr"
    return {
        "i1": "int32_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp16": "float",
        "bf16": "float",
        "fp32": "float",
        "f32": "float",
        "fp64": "double",
    }[ty]


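# Emit the C source of a CPython extension module whose `launch` function
# parses the Python-level arguments and launches the compiled kernel.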
def generate_launcher(constants, signature, ids):
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
    signature, desc_start_idx = generate_cu_signature(constants, signature, ids)
    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

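    # C type each argument has when first read from the Python argument tuple;
    # pointer-like arguments arrive as PyObject* and are converted later.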
    def _extracted_type(ty):
        if ty[0] == '*':
            return "PyObject*"
        return {
            'i1': 'int32_t',
            'i32': 'int32_t',
            'i64': 'int64_t',
            'u32': 'uint32_t',
            'u64': 'uint64_t',
            'fp16': 'float',
            'bf16': 'float',
            'fp32': 'float',
            'f32': 'float',
            'fp64': 'double',
        }[ty]

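    # PyArg_ParseTuple format character for each C argument type.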
    def format_of(ty):
        return {
            "PyObject*": "O",
            "float": "f",
            "double": "d",
            "long": "l",
            "uint32_t": "I",
            "int32_t": "i",
            "uint64_t": "K",
            "int64_t": "L",
        }[ty]

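    # The fixed "iiiiiiiiiKKOOO" prefix matches the launcher's leading
    # parameters: grid dims, num_warps, num_ctas, cluster dims, and shared
    # memory (9 ints), the stream and function handles (2 uint64), and the
    # enter/exit hooks plus the compiled kernel (3 objects).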
    format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])

    # generate glue code
    folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
    params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
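    # `params` holds the indices of arguments actually forwarded to the kernel:
    # folded constants (minus constexprs) are dropped, while descriptor
    # arguments at or beyond desc_start_idx are always kept.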
    src = f"""
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>
#include <dlfcn.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
  {{
    const char* prefix = "Triton Error [CUDA]: ";
    const char* str;
    cuGetErrorString(code, &str);
    char err[1024] = {{0}};
    strcat(err, prefix);
    strcat(err, str);
    PyGILState_STATE gil_state;
    gil_state = PyGILState_Ensure();
    PyErr_SetString(PyExc_RuntimeError, err);
    PyGILState_Release(gil_state);
  }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

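// cuLaunchKernelEx is looked up at runtime via dlopen/dlsym rather than linked
// directly, so the lookup only has to succeed when the cluster-launch path is
// actually taken.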
typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);

static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  void* handle = dlopen("libcuda.so", RTLD_LAZY);
  if (!handle) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so");
    return NULL;
  }}
  // Clear any existing error
  dlerror();
  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  // Check for errors
  const char *dlsym_error = dlerror();
  if (dlsym_error) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}

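// Dispatch: plain cuLaunchKernel when each cluster holds a single CTA,
// cuLaunchKernelEx with cluster-dimension attributes otherwise.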
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
  if (gridX*gridY*gridZ > 0) {{
    if (num_ctas == 1) {{
      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
    }} else {{
      CUlaunchAttribute launchAttr[2];
      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      launchAttr[0].value.clusterDim.x = clusterDimX;
      launchAttr[0].value.clusterDim.y = clusterDimY;
      launchAttr[0].value.clusterDim.z = clusterDimZ;
      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      CUlaunchConfig config;
      config.gridDimX = gridX * clusterDimX;
      config.gridDimY = gridY * clusterDimY;
      config.gridDimZ = gridZ * clusterDimZ;
      config.blockDimX = 32 * num_warps;
      config.blockDimY = 1;
      config.blockDimZ = 1;
      config.sharedMemBytes = shared_memory;
      config.hStream = stream;
      config.attrs = launchAttr;
      config.numAttrs = 2;
      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
      if (cuLaunchKernelExHandle == NULL) {{
        cuLaunchKernelExHandle = getLaunchKernelExHandle();
      }}
      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
    }}
  }}
}}

typedef struct _DevicePtrInfo {{
  CUdeviceptr dev_ptr;
  bool valid;
}} DevicePtrInfo;

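// Extract a device pointer from a Python object: a raw integer address,
// None (a valid nullptr), or any object exposing a data_ptr() method.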
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
      return ptr_info;
    }}
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
    if(!ptr_info.dev_ptr)
      return ptr_info;
    uint64_t dev_ptr;
    int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == CUDA_ERROR_INVALID_VALUE) {{
      PyErr_Format(PyExc_ValueError,
                   "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
      ptr_info.valid = false;
    }}
    ptr_info.dev_ptr = dev_ptr;
    Py_DECREF(ret);  // Thanks ChatGPT!
    return ptr_info;
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  ptr_info.valid = false;
  return ptr_info;
}}

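// Module entry point: parse the argument tuple, run the enter hook, convert
// pointer arguments, launch the kernel without holding the GIL, then run the
// exit hook.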
static PyObject* launch(PyObject* self, PyObject* args) {{
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int num_warps;
  int num_ctas;
  int clusterDimX;
  int clusterDimY;
  int clusterDimZ;
  int shared_memory;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *compiled_kernel = NULL;
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
    return NULL;
  }}

  if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{
    return NULL;
  }}

  // raise exception asap
  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
  Py_BEGIN_ALLOW_THREADS;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0] == "*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
  }}

  if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{
    return NULL;
  }}

  // return None
  Py_INCREF(Py_None);
  return Py_None;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""
    return src