[ROCM] Enable ROCM Backend #1: Empty Kernel (#1312)

This PR is a first in a series of PRs to import the changes that we have
made to enable ROCM on [our
fork](https://github.com/ROCmSoftwarePlatform/triton) of triton.

The PR contains the major changes to the python frontend and enough
changes to the c++ backend to allow compilation and running of the empty
kernel. We use the ROCm CI, added a few weeks ago, to verify these changes.

---------

Co-authored-by: Ronan Keryell <ronan@keryell.fr>
This commit is contained in:
Michael Melesse
2023-03-24 20:18:27 -04:00
committed by GitHub
parent 89d8fe6502
commit a9c87245b4
33 changed files with 1602 additions and 131 deletions

View File

@@ -23,9 +23,11 @@
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "triton/Tools/Sys/GetPlatform.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
@@ -64,7 +66,7 @@ void init_triton_runtime(py::module &&m) {
py::enum_<backend_t>(m, "backend")
.value("HOST", HOST)
.value("CUDA", CUDA)
// .value("ROCM", ROCM)
.value("ROCM", ROCM)
.export_values();
}
@@ -1487,6 +1489,9 @@ void init_triton_translation(py::module &m) {
return shared.getInt();
});
m.def(
"set_rocm", []() { setROCM(); }, ret::take_ownership);
m.def(
"translate_triton_gpu_to_llvmir",
[](mlir::ModuleOp op, int computeCapability) {
@@ -1587,6 +1592,24 @@ void init_triton_translation(py::module &m) {
const std::vector<std::string> &paths) {
::mlir::triton::addExternalLibs(op, names, paths);
});
m.def(
"translate_llvmir_to_hsaco",
[](const std::string llvmIR, std::string gfx_arch, std::string gfx_triple,
std::string gfx_features) -> std::tuple<std::string, std::string> {
// create LLVM module from C++
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
// translate module to HSACO
auto hsacoCode = triton::translateLLVMIRToHSACO(
*module, gfx_arch, gfx_triple, gfx_features);
return hsacoCode;
},
ret::take_ownership);
}
void init_triton(py::module &m) {

View File

@@ -19,6 +19,7 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import setuptools
import torch
from filelock import FileLock
import triton
@@ -27,6 +28,14 @@ from . import impl
from .tools.disasm import extract
def static_vars(**kwargs):
    """Decorator factory that attaches each keyword argument to the
    decorated function as an attribute (poor man's "static" variables).

    Example: @static_vars(counter=0) makes func.counter available.
    """
    def decorate(func):
        for attr_name, attr_value in kwargs.items():
            setattr(func, attr_name, attr_value)
        return func
    return decorate
def str_to_ty(name):
if name[0] == "*":
ty = str_to_ty(name[1:])
@@ -1172,6 +1181,29 @@ def ptx_get_kernel_name(ptx: str) -> str:
return line.split()[-1]
def amdgcn_get_kernel_name(amdgcn: str) -> str:
    '''
    Extract the kernel name from AMDGCN assembly text.

    The name is taken from the first `.globl` directive; it is required
    later when launching the kernel. Returns None if no such directive
    is present.
    '''
    assert amdgcn
    stripped_lines = (raw.strip() for raw in amdgcn.split('\n'))
    for line in stripped_lines:
        if line.startswith('.globl'):
            return line.split()[-1].strip()
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
    '''
    Translate LLVM-IR to AMDGCN assembly and an HSACO binary for a fully
    specified GPU architecture.

    :param mod: LLVM-IR source, forwarded to the C++ translator
                (`translate_llvmir_to_hsaco` takes a string of LLVM IR)
    :param gfx_arch: AMD GPU architecture name, e.g. "gfx906"
    :param gfx_triple: target triple, e.g. "amdgcn-amd-amdhsa"
    :param gfx_features: target features, e.g. "+sramecc,-xnack"
    :return:
        - AMDGCN code
        - Path to HSACO object
    '''
    return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
@functools.lru_cache
def ptx_get_version(cuda_version) -> int:
'''
@@ -1215,7 +1247,7 @@ instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equ
def ty_to_cpp(ty):
if ty[0] == '*':
return "CUdeviceptr"
return "hipDeviceptr_t" if torch.version.hip is not None else "CUdeviceptr"
return {
"i1": "int32_t",
"i8": "int8_t",
@@ -1280,7 +1312,148 @@ def generate_launcher(constants, signature):
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
src = f"""
if torch.version.hip is not None:
src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
#include <stdio.h>
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [HIP]: ";
const char* str = hipGetErrorString(code);
char err[1024] = {{0}};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if (gridX*gridY*gridZ > 0) {{
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
typedef struct _DevicePtrInfo {{
hipDeviceptr_t dev_ptr;
bool valid;
}} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if (ptr) {{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
if (!ptr_info.dev_ptr)
return ptr_info;
uint64_t dev_ptr;
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == hipErrorInvalidValue) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return ptr_info;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject_CallObject(launch_enter_hook, args);
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
if (PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
else:
src = f"""
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>
@@ -1411,7 +1584,6 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
return m;
}}
"""
return src
@@ -1474,17 +1646,26 @@ def quiet():
sys.stdout, sys.stderr = old_stdout, old_stderr
def _build(name, src, srcdir):
cuda_lib_dirs = libcuda_dirs()
base_dir = os.path.dirname(__file__)
cuda_path = os.path.join(base_dir, "third_party", "cuda")
@functools.lru_cache()
def rocm_path_dir():
    """Return the ROCm installation root.

    Honors the ROCM_PATH environment variable, falling back to
    /opt/rocm. The result is cached for the life of the process.
    """
    return os.environ.get("ROCM_PATH", "/opt/rocm")
cu_include_dir = os.path.join(cuda_path, "include")
triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
cuda_header = os.path.join(cu_include_dir, "cuda.h")
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
cu_include_dir = triton_include_dir
def _build(name, src, srcdir):
if torch.version.hip is not None:
hip_lib_dir = os.path.join(rocm_path_dir(), "lib")
hip_include_dir = os.path.join(rocm_path_dir(), "include")
else:
cuda_lib_dirs = libcuda_dirs()
base_dir = os.path.dirname(__file__)
cuda_path = os.path.join(base_dir, "third_party", "cuda")
cu_include_dir = os.path.join(cuda_path, "include")
triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
cuda_header = os.path.join(cu_include_dir, "cuda.h")
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
cu_include_dir = triton_include_dir
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
@@ -1507,9 +1688,12 @@ def _build(name, src, srcdir):
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)
if torch.version.hip is not None:
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
@@ -1579,7 +1763,26 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
return module, md5, True, False
#
def get_amdgpu_arch_fulldetails():
    """
    Get the full AMDGPU ISA details needed for compiling, i.e.
    arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-

    :return: [arch_triple, arch_name, arch_features], or None when the
             details cannot be detected (e.g. rocminfo is unavailable).
    """
    try:
        # Query the ROCm runtime; rocminfo prints an ISA line such as
        # "amdgcn-amd-amdhsa--gfx906:sramecc+:xnack-".
        rocminfo = subprocess.check_output(os.path.join(rocm_path_dir(), 'bin', 'rocminfo')).decode()
        gfx_arch_details = re.search('amd.*', rocminfo).group(0).strip().split('--')
        arch_triple = gfx_arch_details[0]
        arch_name_features = gfx_arch_details[1].split(':')
        arch_name = arch_name_features[0]
        arch_features = ""
        # Two trailing fields (e.g. "sramecc+" and "xnack-") are mapped to
        # LLVM's "+feat1,-feat2" target-features syntax.
        if len(arch_name_features) == 3:
            arch_features = "+" + re.search(r'\w+', arch_name_features[1]).group(0) + ","\
                "-" + re.search(r'\w+', arch_name_features[2]).group(0)
        return [arch_triple, arch_name, arch_features]
    except Exception:
        # Best-effort detection: swallow ordinary failures only. The previous
        # `except BaseException` also masked KeyboardInterrupt/SystemExit.
        return None
def make_stub(name, signature, constants):
@@ -1663,6 +1866,7 @@ def _get_jsonable_constants(constants):
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
@static_vars(discovered_gfx_arch_fulldetails=get_amdgpu_arch_fulldetails())
def compile(fn, **kwargs):
capability = kwargs.get("cc", None)
if capability is None:
@@ -1680,19 +1884,49 @@ def compile(fn, **kwargs):
extern_libs = kwargs.get("extern_libs", dict())
debug = kwargs.get("debug", False)
# build compilation stages
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
}
if torch.version.hip is not None:
_triton.set_rocm()
if extern_libs is None:
extern_libs = get_amdgcn_bitcode_paths()
else:
extern_libs.update(get_amdgcn_bitcode_paths())
for key in list(extern_libs):
if extern_libs[key] == '' or extern_libs[key] is None:
extern_libs.pop(key)
gfx_arch_full_details = compile.discovered_gfx_arch_fulldetails
gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
if gfx_arch is None:
raise RuntimeError('gfx_arch is None (not specified)')
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"amdgcn": (lambda path: Path(path).read_text(),
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
gfx_arch_full_details[2])),
}
else:
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
}
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)
@@ -1746,24 +1980,39 @@ def compile(fn, **kwargs):
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile) in list(stages.items())[first_stage:]:
for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
ir in metadata["ctime"] and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
if ir == "amdgcn":
next_module = (parse(path), parse(fn_cache_manager._make_path(f"{name}.hsaco_path")))
else:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
next_module = compile_kernel(module)
if ir == "amdgcn":
fn_cache_manager.put(next_module[0], f"{name}.{ir}")
fn_cache_manager.put(next_module[1], f"{name}.hsaco_path")
else:
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "cubin":
asm[ir] = next_module
elif ir == "amdgcn":
asm[ir] = str(next_module[0])
else:
asm[ir] = str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
if ir == "amdgcn":
metadata["name"] = amdgcn_get_kernel_name(next_module[0])
asm["hsaco_path"] = next_module[1]
module = next_module
# write-back metadata
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
@@ -1771,6 +2020,45 @@ def compile(fn, **kwargs):
return CompiledKernel(fn, so_path, metadata, asm)
# Discover the GPU architecture once at import time and stash it on the
# function itself via the static_vars decorator.
@static_vars(discovered_gfx_arch_fulldetails=get_amdgpu_arch_fulldetails())
def _get_amdgcn_bitcode_paths():
    """Collect paths to the ROCm device-library bitcode files to link
    against, keyed as 'library_1', 'library_2', ... Missing files are
    silently skipped. Returns {} on non-HIP builds."""
    if torch.version.hip is not None:
        # Architecture-independent device libraries.
        gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
                                               "ocml.bc",
                                               "ockl.bc",
                                               "oclc_finite_only_off.bc",
                                               "oclc_daz_opt_off.bc",
                                               "oclc_correctly_rounded_sqrt_on.bc",
                                               "oclc_unsafe_math_off.bc",
                                               "oclc_wavefrontsize64_on.bc"]
        # NOTE(review): if architecture detection failed,
        # discovered_gfx_arch_fulldetails is None and the [1] below raises
        # TypeError — confirm whether a fallback is needed here.
        gfx_arch = _get_amdgcn_bitcode_paths.discovered_gfx_arch_fulldetails[1]
        gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
        # Architecture-specific library, e.g. oclc_isa_version_906.bc.
        gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
        # Trailing "/" is preserved by os.path.join, so plain string
        # concatenation below yields valid paths.
        bitcode_path_dir = os.path.join(Path(__file__).parent.resolve(), "third_party/rocm/lib/bitcode/")
        amdgcn_bitcode_paths = {}
        i = 1
        for bc_lib in gpu_arch_agnostic_bitcode_libraries:
            bc_path = bitcode_path_dir + bc_lib
            if os.path.exists(bc_path):
                amdgcn_bitcode_paths['library_' + str(i)] = bc_path
                i += 1
        bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
        if os.path.exists(bc_gfx_path):
            amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
        return amdgcn_bitcode_paths
    else:
        return {}
# The bitcode paths are computed once at import time (via static_vars);
# this accessor just returns the cached result.
@static_vars(amdgcn_bitcode_paths=_get_amdgcn_bitcode_paths())
def get_amdgcn_bitcode_paths():
    """Return the device-library bitcode paths discovered at import time."""
    return get_amdgcn_bitcode_paths.amdgcn_bitcode_paths
class CompiledKernel:
# Hooks for external tools to monitor the execution of triton kernels
@@ -1803,12 +2091,21 @@ class CompiledKernel:
if self.cu_module is not None:
return
device = triton.runtime.jit.get_current_device()
global cuda_utils
init_cuda_utils()
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
if torch.version.hip is not None:
global hip_utils
init_hip_utils()
max_shared = hip_utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = hip_utils.load_binary(self.metadata["name"], self.asm["hsaco_path"], self.shared, device)
else:
global cuda_utils
init_cuda_utils()
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
self.n_spills = n_spills
self.n_regs = n_regs
self.cu_module = mod
@@ -1992,3 +2289,153 @@ def init_cuda_utils():
cuda_utils = None
def init_hip_utils():
    """Lazily create the process-wide HIPUtils singleton on first call;
    subsequent calls are no-ops."""
    global hip_utils
    if hip_utils is None:
        hip_utils = HIPUtils()


# Module-level cache; populated by init_hip_utils().
hip_utils = None
class HIPUtils(object):
    """Builds (once, cached on disk) and loads a small C extension module
    wrapping the HIP driver API: loading HSACO binaries and querying
    device properties. Implemented as a per-process singleton."""

    def __new__(cls):
        # Singleton: reuse one instance per process.
        if not hasattr(cls, 'instance'):
            cls.instance = super(HIPUtils, cls).__new__(cls)
        return cls.instance

    def _generate_src(self):
        # C source of the helper extension. It exposes two functions:
        # load_binary (hipModuleLoadDataEx on an HSACO file) and
        # get_device_properties (hipGetDeviceProperties).
        return """
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdio.h>
#include <stdlib.h>
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [HIP]: ";
const char* str = hipGetErrorString(code);
char err[1024] = {0};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define HIP_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if (PyErr_Occurred()) return NULL; }
static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
return NULL;
hipDeviceProp_t props;
HIP_CHECK(hipGetDeviceProperties(&props, device_id));
// create a struct to hold device properties
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", props.sharedMemPerBlock,
"multiprocessor_count", props.multiProcessorCount,
"sm_clock_rate", props.clockRate,
"mem_clock_rate", props.memoryClockRate,
"mem_bus_width", props.memoryBusWidth);
}
static PyObject* loadBinary(PyObject* self, PyObject* args) {
const char* name;
const char* data;
Py_ssize_t data_size;
int shared;
int device;
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, &device)) {
return NULL;
}
// Open HSACO file
FILE* hsaco_file;
if ((hsaco_file = fopen(data, "rb")) == NULL) {
return NULL;
}
// Read HSCAO file into Buffer
fseek(hsaco_file, 0L, SEEK_END);
size_t hsaco_file_size = ftell(hsaco_file);
unsigned char* hsaco = (unsigned char*) malloc(hsaco_file_size * sizeof(unsigned char));
rewind(hsaco_file);
fread(hsaco, sizeof(unsigned char), hsaco_file_size, hsaco_file);
fclose(hsaco_file);
// set HIP options
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void *optval[] = {(void *)(uintptr_t)errbufsize,
(void *)_err, (void *)(uintptr_t)logbufsize,
(void *)_log, (void *)1};
// launch HIP Binary
hipModule_t mod;
hipFunction_t fun;
hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval);
hipModuleGetFunction(&fun, mod, name);
free(hsaco);
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
if (PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills);
}
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS, "Load provided hsaco into HIP driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {
PyModuleDef_HEAD_INIT,
"hip_utils",
NULL, //documentation
-1, //size
ModuleMethods
};
PyMODINIT_FUNC PyInit_hip_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if (m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
"""

    def __init__(self):
        # Compile the helper once, keyed by the md5 of its source, and
        # cache the resulting shared object; later runs reuse the cache.
        src = self._generate_src()
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = CacheManager(key)
        fname = "hip_utils.so"
        if not cache.has_file(fname):
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = _build("hip_utils", src_path, tmpdir)
                with open(so, "rb") as f:
                    cache.put(f.read(), fname, binary=True)
        # Load the cached extension from its file path and expose its two
        # entry points as bound attributes.
        import importlib.util
        spec = importlib.util.spec_from_file_location("hip_utils", cache._make_path(fname))
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties

View File

@@ -1,9 +1,14 @@
import os
import torch
from .. import impl
from . import core, extern
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
if torch.version.hip is not None:
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc")
else:
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
LIBDEVICE_PATH = os.getenv("TRITON_LIBDEVICE_PATH", LOCAL_PATH)

View File

@@ -3,11 +3,12 @@ import sys
import triton
import triton._C.libtriton.triton as libtriton
import triton.compiler as tc
if __name__ == '__main__':
# valid source and target formats
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']
# set up the argument parser
# TODO: conditional requirements
@@ -17,6 +18,10 @@ if __name__ == '__main__':
help="Target format, one of: " + ', '.join(VALID_FORMATS))
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for")
parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa")
parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack")
parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for")
# parse the args
args = parser.parse_args()
@@ -38,11 +43,52 @@ if __name__ == '__main__':
print(module.str())
sys.exit(0)
if not args.num_warps:
args.num_warps = 4
# llvm-ir -> amdgcn
if args.target == 'amdgcn':
# auto detect available architecture and features
# if nothing detected, set with default values
arch_details = tc.get_amdgpu_arch_fulldetails()
if not arch_details:
arch_name = ""
arch_triple = "amdgcn-amd-amdhsa"
arch_features = ""
else:
arch_triple, arch_name, arch_features = arch_details
# stop processing if architecture name is not automatically detected and is not set manually
if not args.gfx and not arch_name:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
# rewrite default and automatically detected values with manually provided data
if args.gfx:
arch_name = args.gfx
if args.triple:
arch_triple = args.triple
if args.features:
arch_features = args.features
# triton-ir -> triton-gpu-ir
# use compute_capability == 80
module = triton.compiler.ttir_to_ttgir(module, num_warps=args.num_warps) # num_stages=3, compute_capability=80)
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=80)
# triton-gpu-ir -> llvm-ir
# use compute_capability == 80
module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=80)
# llvm-ir -> amdgcn asm, hsaco binary
module, hsaco_path = triton.compiler.llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
print(hsaco_path)
print(module)
sys.exit(0)
if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
# triton-ir -> triton-gpu-ir
module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
module = triton.compiler.ttir_to_ttgir(module, num_warps=args.num_warps)
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
if args.target == 'triton-gpu-ir':
print(module.str())
@@ -54,10 +100,16 @@ if __name__ == '__main__':
print(module)
sys.exit(0)
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
# llvm-ir -> ptx
module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
assert args.target == 'ptx'
if args.target == 'ptx':
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
# llvm-ir -> amdgcn
if args.target == 'amdgcn':
if not args.gfx:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
module, hsaco_path = triton.compiler.llir_to_amdgcn_and_hsaco(module, args.gfx)
print(module)