mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
This PR is a first in a series of PRs to import the changes that we have made to enable ROCM on [our fork](https://github.com/ROCmSoftwarePlatform/triton) of triton. The PR contains the major changes to the python frontend and enough changes to the c++ backend to allow compilation and running of the empty kernel. We use the ROCM ci added a few weeks ago to verify things. --------- Co-authored-by: Ronan Keryell <ronan@keryell.fr>
This commit is contained in:
@@ -23,9 +23,11 @@
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Target/HSACO/HSACOTranslation.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "triton/Tools/Sys/GetPlatform.hpp"
|
||||
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
@@ -64,7 +66,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
py::enum_<backend_t>(m, "backend")
|
||||
.value("HOST", HOST)
|
||||
.value("CUDA", CUDA)
|
||||
// .value("ROCM", ROCM)
|
||||
.value("ROCM", ROCM)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
@@ -1487,6 +1489,9 @@ void init_triton_translation(py::module &m) {
|
||||
return shared.getInt();
|
||||
});
|
||||
|
||||
m.def(
|
||||
"set_rocm", []() { setROCM(); }, ret::take_ownership);
|
||||
|
||||
m.def(
|
||||
"translate_triton_gpu_to_llvmir",
|
||||
[](mlir::ModuleOp op, int computeCapability) {
|
||||
@@ -1587,6 +1592,24 @@ void init_triton_translation(py::module &m) {
|
||||
const std::vector<std::string> &paths) {
|
||||
::mlir::triton::addExternalLibs(op, names, paths);
|
||||
});
|
||||
|
||||
m.def(
|
||||
"translate_llvmir_to_hsaco",
|
||||
[](const std::string llvmIR, std::string gfx_arch, std::string gfx_triple,
|
||||
std::string gfx_features) -> std::tuple<std::string, std::string> {
|
||||
// create LLVM module from C++
|
||||
llvm::LLVMContext context;
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer =
|
||||
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
|
||||
llvm::SMDiagnostic error;
|
||||
std::unique_ptr<llvm::Module> module =
|
||||
llvm::parseIR(buffer->getMemBufferRef(), error, context);
|
||||
// translate module to HSACO
|
||||
auto hsacoCode = triton::translateLLVMIRToHSACO(
|
||||
*module, gfx_arch, gfx_triple, gfx_features);
|
||||
return hsacoCode;
|
||||
},
|
||||
ret::take_ownership);
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
|
||||
@@ -19,6 +19,7 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import setuptools
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
|
||||
import triton
|
||||
@@ -27,6 +28,14 @@ from . import impl
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
def static_vars(**kwargs):
|
||||
def decorate(func):
|
||||
for k in kwargs:
|
||||
setattr(func, k, kwargs[k])
|
||||
return func
|
||||
return decorate
|
||||
|
||||
|
||||
def str_to_ty(name):
|
||||
if name[0] == "*":
|
||||
ty = str_to_ty(name[1:])
|
||||
@@ -1172,6 +1181,29 @@ def ptx_get_kernel_name(ptx: str) -> str:
|
||||
return line.split()[-1]
|
||||
|
||||
|
||||
def amdgcn_get_kernel_name(amdgcn: str) -> str:
|
||||
'''
|
||||
Get kernel name from AMDGCN code.
|
||||
This Kernel name is required when launching the kernel.
|
||||
'''
|
||||
assert amdgcn
|
||||
for line in amdgcn.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith('.globl'):
|
||||
return line.split()[-1].strip()
|
||||
|
||||
|
||||
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
|
||||
'''
|
||||
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
|
||||
:param mod: a TritonGPU dialect module
|
||||
:return:
|
||||
- AMDGCN code
|
||||
- Path to HSACO object
|
||||
'''
|
||||
return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def ptx_get_version(cuda_version) -> int:
|
||||
'''
|
||||
@@ -1215,7 +1247,7 @@ instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equ
|
||||
|
||||
def ty_to_cpp(ty):
|
||||
if ty[0] == '*':
|
||||
return "CUdeviceptr"
|
||||
return "hipDeviceptr_t" if torch.version.hip is not None else "CUdeviceptr"
|
||||
return {
|
||||
"i1": "int32_t",
|
||||
"i8": "int8_t",
|
||||
@@ -1280,7 +1312,148 @@ def generate_launcher(constants, signature):
|
||||
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||
|
||||
# generate glue code
|
||||
src = f"""
|
||||
if torch.version.hip is not None:
|
||||
src = f"""
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {{0}};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if (gridX*gridY*gridZ > 0) {{
|
||||
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
hipDeviceptr_t dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
|
||||
if (ptr) {{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
||||
|
||||
if (!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
|
||||
uint64_t dev_ptr;
|
||||
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == hipErrorInvalidValue) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
if (PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
else:
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
#include <Python.h>
|
||||
@@ -1411,7 +1584,6 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
|
||||
return src
|
||||
|
||||
|
||||
@@ -1474,17 +1646,26 @@ def quiet():
|
||||
sys.stdout, sys.stderr = old_stdout, old_stderr
|
||||
|
||||
|
||||
def _build(name, src, srcdir):
|
||||
cuda_lib_dirs = libcuda_dirs()
|
||||
base_dir = os.path.dirname(__file__)
|
||||
cuda_path = os.path.join(base_dir, "third_party", "cuda")
|
||||
@functools.lru_cache()
|
||||
def rocm_path_dir():
|
||||
return os.getenv("ROCM_PATH", default="/opt/rocm")
|
||||
|
||||
cu_include_dir = os.path.join(cuda_path, "include")
|
||||
triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
|
||||
cuda_header = os.path.join(cu_include_dir, "cuda.h")
|
||||
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
|
||||
if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
|
||||
cu_include_dir = triton_include_dir
|
||||
|
||||
def _build(name, src, srcdir):
|
||||
if torch.version.hip is not None:
|
||||
hip_lib_dir = os.path.join(rocm_path_dir(), "lib")
|
||||
hip_include_dir = os.path.join(rocm_path_dir(), "include")
|
||||
else:
|
||||
cuda_lib_dirs = libcuda_dirs()
|
||||
base_dir = os.path.dirname(__file__)
|
||||
cuda_path = os.path.join(base_dir, "third_party", "cuda")
|
||||
|
||||
cu_include_dir = os.path.join(cuda_path, "include")
|
||||
triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
|
||||
cuda_header = os.path.join(cu_include_dir, "cuda.h")
|
||||
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
|
||||
if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
|
||||
cu_include_dir = triton_include_dir
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
|
||||
# try to avoid setuptools if possible
|
||||
@@ -1507,9 +1688,12 @@ def _build(name, src, srcdir):
|
||||
scheme = 'posix_prefix'
|
||||
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
|
||||
|
||||
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
|
||||
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
|
||||
ret = subprocess.check_call(cc_cmd)
|
||||
if torch.version.hip is not None:
|
||||
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
|
||||
else:
|
||||
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
|
||||
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
|
||||
ret = subprocess.check_call(cc_cmd)
|
||||
|
||||
if ret == 0:
|
||||
return so
|
||||
@@ -1579,7 +1763,26 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
|
||||
cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
|
||||
return module, md5, True, False
|
||||
|
||||
#
|
||||
|
||||
def get_amdgpu_arch_fulldetails():
|
||||
"""
|
||||
get the amdgpu fulll ISA details for compiling:
|
||||
i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-
|
||||
"""
|
||||
try:
|
||||
rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode()
|
||||
gfx_arch_details = re.search('amd.*', rocminfo).group(0).strip().split('--')
|
||||
arch_triple = gfx_arch_details[0]
|
||||
arch_name_features = gfx_arch_details[1].split(':')
|
||||
arch_name = arch_name_features[0]
|
||||
arch_features = ""
|
||||
|
||||
if (len(arch_name_features) == 3):
|
||||
arch_features = "+" + re.search('\\w+', arch_name_features[1]).group(0) + ","\
|
||||
"-" + re.search('\\w+', arch_name_features[2]).group(0)
|
||||
return [arch_triple, arch_name, arch_features]
|
||||
except BaseException:
|
||||
return None
|
||||
|
||||
|
||||
def make_stub(name, signature, constants):
|
||||
@@ -1663,6 +1866,7 @@ def _get_jsonable_constants(constants):
|
||||
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
|
||||
|
||||
@static_vars(discovered_gfx_arch_fulldetails=get_amdgpu_arch_fulldetails())
|
||||
def compile(fn, **kwargs):
|
||||
capability = kwargs.get("cc", None)
|
||||
if capability is None:
|
||||
@@ -1680,19 +1884,49 @@ def compile(fn, **kwargs):
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
debug = kwargs.get("debug", False)
|
||||
# build compilation stages
|
||||
stages = {
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
|
||||
"ttgir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, capability)),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, capability))
|
||||
}
|
||||
if torch.version.hip is not None:
|
||||
_triton.set_rocm()
|
||||
if extern_libs is None:
|
||||
extern_libs = get_amdgcn_bitcode_paths()
|
||||
else:
|
||||
extern_libs.update(get_amdgcn_bitcode_paths())
|
||||
|
||||
for key in list(extern_libs):
|
||||
if extern_libs[key] == '' or extern_libs[key] is None:
|
||||
extern_libs.pop(key)
|
||||
|
||||
gfx_arch_full_details = compile.discovered_gfx_arch_fulldetails
|
||||
gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
|
||||
if gfx_arch is None:
|
||||
raise RuntimeError('gfx_arch is None (not specified)')
|
||||
stages = {
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
||||
"amdgcn": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
|
||||
gfx_arch_full_details[0],
|
||||
gfx_arch_full_details[2])),
|
||||
}
|
||||
else:
|
||||
stages = {
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
|
||||
"ttgir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, capability)),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, capability))
|
||||
}
|
||||
|
||||
# find out the signature of the function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs.get("configs", None)
|
||||
@@ -1746,24 +1980,39 @@ def compile(fn, **kwargs):
|
||||
asm = dict()
|
||||
module = fn
|
||||
# run compilation pipeline and populate metadata
|
||||
for ir, (parse, compile) in list(stages.items())[first_stage:]:
|
||||
for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
|
||||
path = fn_cache_manager._make_path(f"{name}.{ir}")
|
||||
if ir == ext:
|
||||
next_module = parse(fn)
|
||||
elif os.path.exists(path) and\
|
||||
ir in metadata["ctime"] and\
|
||||
os.path.getctime(path) == metadata["ctime"][ir]:
|
||||
next_module = parse(path)
|
||||
if ir == "amdgcn":
|
||||
next_module = (parse(path), parse(fn_cache_manager._make_path(f"{name}.hsaco_path")))
|
||||
else:
|
||||
next_module = parse(path)
|
||||
else:
|
||||
next_module = compile(module)
|
||||
fn_cache_manager.put(next_module, f"{name}.{ir}")
|
||||
next_module = compile_kernel(module)
|
||||
if ir == "amdgcn":
|
||||
fn_cache_manager.put(next_module[0], f"{name}.{ir}")
|
||||
fn_cache_manager.put(next_module[1], f"{name}.hsaco_path")
|
||||
else:
|
||||
fn_cache_manager.put(next_module, f"{name}.{ir}")
|
||||
if os.path.exists(path):
|
||||
metadata["ctime"][ir] = os.path.getctime(path)
|
||||
asm[ir] = next_module if ir == "cubin" else str(next_module)
|
||||
if ir == "cubin":
|
||||
asm[ir] = next_module
|
||||
elif ir == "amdgcn":
|
||||
asm[ir] = str(next_module[0])
|
||||
else:
|
||||
asm[ir] = str(next_module)
|
||||
if ir == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = _triton.get_shared_memory_size(module)
|
||||
if ir == "ptx":
|
||||
metadata["name"] = ptx_get_kernel_name(next_module)
|
||||
if ir == "amdgcn":
|
||||
metadata["name"] = amdgcn_get_kernel_name(next_module[0])
|
||||
asm["hsaco_path"] = next_module[1]
|
||||
module = next_module
|
||||
# write-back metadata
|
||||
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
|
||||
@@ -1771,6 +2020,45 @@ def compile(fn, **kwargs):
|
||||
return CompiledKernel(fn, so_path, metadata, asm)
|
||||
|
||||
|
||||
@static_vars(discovered_gfx_arch_fulldetails=get_amdgpu_arch_fulldetails())
|
||||
def _get_amdgcn_bitcode_paths():
|
||||
if torch.version.hip is not None:
|
||||
gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
|
||||
"ocml.bc",
|
||||
"ockl.bc",
|
||||
"oclc_finite_only_off.bc",
|
||||
"oclc_daz_opt_off.bc",
|
||||
"oclc_correctly_rounded_sqrt_on.bc",
|
||||
"oclc_unsafe_math_off.bc",
|
||||
"oclc_wavefrontsize64_on.bc"]
|
||||
|
||||
gfx_arch = _get_amdgcn_bitcode_paths.discovered_gfx_arch_fulldetails[1]
|
||||
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
|
||||
|
||||
gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
|
||||
bitcode_path_dir = os.path.join(Path(__file__).parent.resolve(), "third_party/rocm/lib/bitcode/")
|
||||
|
||||
amdgcn_bitcode_paths = {}
|
||||
i = 1
|
||||
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
|
||||
bc_path = bitcode_path_dir + bc_lib
|
||||
if os.path.exists(bc_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_path
|
||||
i += 1
|
||||
bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
|
||||
if os.path.exists(bc_gfx_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
|
||||
|
||||
return amdgcn_bitcode_paths
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
@static_vars(amdgcn_bitcode_paths=_get_amdgcn_bitcode_paths())
|
||||
def get_amdgcn_bitcode_paths():
|
||||
return get_amdgcn_bitcode_paths.amdgcn_bitcode_paths
|
||||
|
||||
|
||||
class CompiledKernel:
|
||||
|
||||
# Hooks for external tools to monitor the execution of triton kernels
|
||||
@@ -1803,12 +2091,21 @@ class CompiledKernel:
|
||||
if self.cu_module is not None:
|
||||
return
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
global cuda_utils
|
||||
init_cuda_utils()
|
||||
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
|
||||
if self.shared > max_shared:
|
||||
raise OutOfResources(self.shared, max_shared, "shared memory")
|
||||
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
if torch.version.hip is not None:
|
||||
global hip_utils
|
||||
init_hip_utils()
|
||||
max_shared = hip_utils.get_device_properties(device)["max_shared_mem"]
|
||||
if self.shared > max_shared:
|
||||
raise OutOfResources(self.shared, max_shared, "shared memory")
|
||||
mod, func, n_regs, n_spills = hip_utils.load_binary(self.metadata["name"], self.asm["hsaco_path"], self.shared, device)
|
||||
else:
|
||||
global cuda_utils
|
||||
init_cuda_utils()
|
||||
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
|
||||
if self.shared > max_shared:
|
||||
raise OutOfResources(self.shared, max_shared, "shared memory")
|
||||
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
|
||||
self.n_spills = n_spills
|
||||
self.n_regs = n_regs
|
||||
self.cu_module = mod
|
||||
@@ -1992,3 +2289,153 @@ def init_cuda_utils():
|
||||
|
||||
|
||||
cuda_utils = None
|
||||
|
||||
|
||||
def init_hip_utils():
|
||||
global hip_utils
|
||||
if hip_utils is None:
|
||||
hip_utils = HIPUtils()
|
||||
|
||||
|
||||
hip_utils = None
|
||||
|
||||
|
||||
class HIPUtils(object):
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
cls.instance = super(HIPUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
def _generate_src(self):
|
||||
return """
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {0};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if (PyErr_Occurred()) return NULL; }
|
||||
|
||||
static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
|
||||
int device_id;
|
||||
if (!PyArg_ParseTuple(args, "i", &device_id))
|
||||
return NULL;
|
||||
|
||||
hipDeviceProp_t props;
|
||||
HIP_CHECK(hipGetDeviceProperties(&props, device_id));
|
||||
|
||||
// create a struct to hold device properties
|
||||
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", props.sharedMemPerBlock,
|
||||
"multiprocessor_count", props.multiProcessorCount,
|
||||
"sm_clock_rate", props.clockRate,
|
||||
"mem_clock_rate", props.memoryClockRate,
|
||||
"mem_bus_width", props.memoryBusWidth);
|
||||
}
|
||||
|
||||
static PyObject* loadBinary(PyObject* self, PyObject* args) {
|
||||
const char* name;
|
||||
const char* data;
|
||||
Py_ssize_t data_size;
|
||||
int shared;
|
||||
int device;
|
||||
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, &device)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Open HSACO file
|
||||
FILE* hsaco_file;
|
||||
if ((hsaco_file = fopen(data, "rb")) == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Read HSCAO file into Buffer
|
||||
fseek(hsaco_file, 0L, SEEK_END);
|
||||
size_t hsaco_file_size = ftell(hsaco_file);
|
||||
unsigned char* hsaco = (unsigned char*) malloc(hsaco_file_size * sizeof(unsigned char));
|
||||
rewind(hsaco_file);
|
||||
fread(hsaco, sizeof(unsigned char), hsaco_file_size, hsaco_file);
|
||||
fclose(hsaco_file);
|
||||
|
||||
// set HIP options
|
||||
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
|
||||
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
|
||||
hipJitOptionLogVerbose};
|
||||
const unsigned int errbufsize = 8192;
|
||||
const unsigned int logbufsize = 8192;
|
||||
char _err[errbufsize];
|
||||
char _log[logbufsize];
|
||||
void *optval[] = {(void *)(uintptr_t)errbufsize,
|
||||
(void *)_err, (void *)(uintptr_t)logbufsize,
|
||||
(void *)_log, (void *)1};
|
||||
|
||||
// launch HIP Binary
|
||||
hipModule_t mod;
|
||||
hipFunction_t fun;
|
||||
hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval);
|
||||
hipModuleGetFunction(&fun, mod, name);
|
||||
free(hsaco);
|
||||
|
||||
// get allocated registers and spilled registers from the function
|
||||
int n_regs = 0;
|
||||
int n_spills = 0;
|
||||
if (PyErr_Occurred()) {
|
||||
return NULL;
|
||||
}
|
||||
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills);
|
||||
}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS, "Load provided hsaco into HIP driver"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"hip_utils",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC PyInit_hip_utils(void) {
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if (m == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
src = self._generate_src()
|
||||
key = hashlib.md5(src.encode("utf-8")).hexdigest()
|
||||
cache = CacheManager(key)
|
||||
fname = "hip_utils.so"
|
||||
if not cache.has_file(fname):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
src_path = os.path.join(tmpdir, "main.c")
|
||||
with open(src_path, "w") as f:
|
||||
f.write(src)
|
||||
so = _build("hip_utils", src_path, tmpdir)
|
||||
with open(so, "rb") as f:
|
||||
cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("hip_utils", cache._make_path(fname))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
self.load_binary = mod.load_binary
|
||||
self.get_device_properties = mod.get_device_properties
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from .. import impl
|
||||
from . import core, extern
|
||||
|
||||
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
|
||||
if torch.version.hip is not None:
|
||||
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc")
|
||||
else:
|
||||
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
|
||||
LIBDEVICE_PATH = os.getenv("TRITON_LIBDEVICE_PATH", LOCAL_PATH)
|
||||
|
||||
|
||||
|
||||
@@ -3,11 +3,12 @@ import sys
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as libtriton
|
||||
import triton.compiler as tc
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# valid source and target formats
|
||||
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']
|
||||
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']
|
||||
|
||||
# set up the argument parser
|
||||
# TODO: conditional requirements
|
||||
@@ -17,6 +18,10 @@ if __name__ == '__main__':
|
||||
help="Target format, one of: " + ', '.join(VALID_FORMATS))
|
||||
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
|
||||
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
|
||||
parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for")
|
||||
parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa")
|
||||
parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack")
|
||||
parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for")
|
||||
|
||||
# parse the args
|
||||
args = parser.parse_args()
|
||||
@@ -38,11 +43,52 @@ if __name__ == '__main__':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
|
||||
if not args.num_warps:
|
||||
args.num_warps = 4
|
||||
|
||||
# llvm-ir -> amdgcn
|
||||
if args.target == 'amdgcn':
|
||||
# auto detect available architecture and features
|
||||
# if nothing detected, set with default values
|
||||
arch_details = tc.get_amdgpu_arch_fulldetails()
|
||||
if not arch_details:
|
||||
arch_name = ""
|
||||
arch_triple = "amdgcn-amd-amdhsa"
|
||||
arch_features = ""
|
||||
else:
|
||||
arch_triple, arch_name, arch_features = arch_details
|
||||
|
||||
# stop processing if architecture name is not automatically detected and is not set manually
|
||||
if not args.gfx and not arch_name:
|
||||
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
|
||||
|
||||
# rewrite default and automatically detected values with manually provided data
|
||||
if args.gfx:
|
||||
arch_name = args.gfx
|
||||
if args.triple:
|
||||
arch_triple = args.triple
|
||||
if args.features:
|
||||
arch_features = args.features
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
# use compute_capability == 80
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=args.num_warps) # num_stages=3, compute_capability=80)
|
||||
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=80)
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
# use compute_capability == 80
|
||||
module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=80)
|
||||
# llvm-ir -> amdgcn asm, hsaco binary
|
||||
module, hsaco_path = triton.compiler.llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
|
||||
|
||||
print(hsaco_path)
|
||||
print(module)
|
||||
sys.exit(0)
|
||||
|
||||
if not args.sm:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=args.num_warps)
|
||||
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
@@ -54,10 +100,16 @@ if __name__ == '__main__':
|
||||
print(module)
|
||||
sys.exit(0)
|
||||
|
||||
if not args.ptx_version:
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
|
||||
# llvm-ir -> ptx
|
||||
module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
|
||||
assert args.target == 'ptx'
|
||||
if args.target == 'ptx':
|
||||
if not args.ptx_version:
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
|
||||
|
||||
# llvm-ir -> amdgcn
|
||||
if args.target == 'amdgcn':
|
||||
if not args.gfx:
|
||||
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
|
||||
module, hsaco_path = triton.compiler.llir_to_amdgcn_and_hsaco(module, args.gfx)
|
||||
|
||||
print(module)
|
||||
|
||||
Reference in New Issue
Block a user