[ROCM] Core Functionality for AMD (#1983)

* this pr adds a third party backend for triton that works on AMD
* this expose a lot of the work that has been done in our
[fork](https://github.com/ROCmSoftwarePlatform/triton)
* most unit tests on `test_core.py` pass
* it skips some unit tests for various reasons
* we plan to follow up with more prs improving Functionality and
Performance in the future

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Michael Melesse
2023-10-16 15:06:07 -05:00
parent 833c9b985f
commit 09ba348f87
17 changed files with 264 additions and 377 deletions

View File

@@ -5,6 +5,7 @@ import importlib.util
import os
import re
import subprocess
import traceback
from typing import Dict
from ..runtime.driver import DriverBase
@@ -94,7 +95,7 @@ def get_backend(device_type: str):
try:
importlib.import_module(device_backend_package_name, package=__spec__.name)
except Exception:
return None
traceback.print_exc()
else:
return None
return _backends[device_type] if device_type in _backends else None

View File

@@ -5,19 +5,18 @@ import hashlib
import json
import os
import re
import subprocess
import tempfile
from collections import namedtuple
from pathlib import Path
from typing import Any, Tuple
from typing import Any
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir, get_arch_info,
get_warp_size)
translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
@@ -214,71 +213,6 @@ def ptx_to_cubin(ptx: str, arch: int):
return compile_ptx_to_cubin(ptx, ptxas, arch)
# AMDGCN translation
def get_amdgcn_bitcode_paths(arch):
gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
"ocml.bc",
"ockl.bc",
"oclc_finite_only_off.bc",
"oclc_daz_opt_off.bc",
"oclc_correctly_rounded_sqrt_on.bc",
"oclc_unsafe_math_off.bc",
"oclc_wavefrontsize64_on.bc",
"oclc_abi_version_400.bc",]
gfx_arch = arch[1]
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")
amdgcn_bitcode_paths = {}
i = 0
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
bc_path = bitcode_path_dir + bc_lib
if os.path.exists(bc_path):
amdgcn_bitcode_paths['library_' + str(i)] = bc_path
i += 1
bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
if os.path.exists(bc_gfx_path):
amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
return amdgcn_bitcode_paths
def get_amdgpu_arch_fulldetails():
"""
get the amdgpu full ISA details for compiling:
i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-
"""
try:
# TODO: package rocm.cc with Triton
arch_info = get_arch_info()
warp_size = get_warp_size()
gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
arch_triple = gfx_arch_details[0]
arch_name_features = gfx_arch_details[1].split(':')
arch_name = arch_name_features[0]
arch_features = ""
return [arch_triple, arch_name, arch_features, warp_size]
except BaseException as e:
print("Error: Attempting to get amgpu ISA Details {}".format(e))
return None
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
'''
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
:param mod: a TritonGPU dialect module
:return:
- AMDGCN code
- Path to HSACO object
'''
return translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
@@ -347,8 +281,10 @@ arg_type_pattern = {
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
if is_hip():
ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:'
else:
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
def _get_jsonable_constants(constants):
@@ -389,17 +325,10 @@ def is_hip():
from ..language.semantic import gpu_matrix_core_version
def get_architecture_descriptor(capability):
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
if capability is None:
if torch.version.hip is None:
device = get_current_device()
capability = get_device_capability(device)
capability = capability[0] * 10 + capability[1]
else:
capability = get_amdgpu_arch_fulldetails()
device = get_current_device()
capability = get_device_capability(device)
capability = capability[0] * 10 + capability[1]
return capability
@@ -429,23 +358,6 @@ def get_arch_default_num_stages(device_type, capability=None):
return num_stages
def add_rocm_stages(arch, extern_libs, stages):
extern_libs.update(get_amdgcn_bitcode_paths(arch))
for key in list(extern_libs):
if extern_libs[key] == '' or extern_libs[key] is None:
extern_libs.pop(key)
gfx_arch_full_details = arch
gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
if gfx_arch is None:
raise RuntimeError('gfx_arch is None (not specified)')
stages["amdgcn"] = (lambda path: Path(path).read_text(),
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
gfx_arch_full_details[2]))
def add_cuda_stages(arch, extern_libs, stages):
stages["ptx"] = (lambda path: Path(path).read_text(),
@@ -457,23 +369,22 @@ def add_cuda_stages(arch, extern_libs, stages):
def compile(fn, **kwargs):
# Get device type to decide which backend should be used
device_type = kwargs.get("device_type", "cuda")
_device_backend = get_backend(device_type)
capability = kwargs.get("cc", None)
if device_type in ["cuda", "hip"]:
# hip with kwargs.get("cc", None) causes multiprocessing issues in torch.compile
if device_type == "hip":
arch = get_architecture_descriptor(None if type(capability) is int else capability)
else:
arch = get_architecture_descriptor(capability)
if is_hip():
device_type = "hip"
if device_type == "cuda":
_device_backend = get_backend(device_type)
arch = get_architecture_descriptor(capability)
else:
_device_backend = get_backend(device_type)
assert _device_backend
arch = _device_backend.get_architecture_descriptor(**kwargs)
is_cuda = device_type == "cuda" and _is_cuda(arch)
is_hip = device_type in ["cuda", "hip"] and not is_cuda
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
if is_hip():
is_cuda = False
context = ir.context()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
@@ -506,14 +417,20 @@ def compile(fn, **kwargs):
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
add_cuda_stages(arch, extern_libs, stages)
elif is_hip:
add_rocm_stages(arch, extern_libs, stages)
elif device_type == "hip":
_device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages)
elif device_type == "xpu":
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
_device_backend.add_stages(arch, extern_libs, stages)
else:
# pass the user's configuration to the backend device.
arch["num_warps"] = num_warps
@@ -632,17 +549,23 @@ def compile(fn, **kwargs):
else:
asm[ir_name] = str(next_module)
if ir_name == "llir" and "shared" not in metadata:
metadata["shared"] = get_shared_memory_size(module)
if is_hip():
metadata["shared"] = _device_backend.get_shared_memory_size(module)
else:
metadata["shared"] = get_shared_memory_size(module)
if ir_name == "ttgir":
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
if metadata["enable_warp_specialization"]:
metadata["num_warps"] = get_num_warps(next_module)
if is_hip():
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
else:
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
if metadata["enable_warp_specialization"]:
metadata["num_warps"] = get_num_warps(next_module)
if ir_name == "ptx":
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
if ir_name == "amdgcn":
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
asm["hsaco_path"] = next_module[1]
if not is_cuda and not is_hip:
if not is_cuda and not is_hip():
_device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
module = next_module
@@ -667,7 +590,7 @@ def compile(fn, **kwargs):
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
# cache manager
if is_cuda or is_hip:
if is_cuda:
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
else:
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
@@ -707,7 +630,7 @@ class CompiledKernel:
self.tensormaps_info = metadata["tensormaps_info"]
self.constants = metadata["constants"]
self.device_type = metadata["device_type"]
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda"] else None
# initialize asm dict
self.asm = asm
# binaries are lazily initialized
@@ -721,7 +644,7 @@ class CompiledKernel:
if self.cu_module is not None:
return
if self.device_type in ["cuda", "hip"]:
if self.device_type in ["cuda"]:
device = get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
@@ -767,7 +690,7 @@ class CompiledKernel:
def runner(*args, stream=None):
args_expand = self.assemble_tensormap_to_arg(args)
if stream is None:
if self.device_type in ["cuda", "hip"]:
if self.device_type in ["cuda"]:
stream = get_cuda_stream()
else:
stream = get_backend(self.device_type).get_stream(None)

View File

@@ -3,16 +3,11 @@ import os
import tempfile
from ..common import _build
from ..common.build import is_hip
from ..runtime.cache import get_cache_manager
from ..runtime.jit import version_key
from .utils import generate_cu_signature
def is_hip():
import torch
return torch.version.hip is not None
# ----- stub --------
@@ -103,150 +98,9 @@ def generate_launcher(constants, signature, ids):
format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
if is_hip():
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
#include <stdbool.h>
#include <dlfcn.h>
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [HIP]: ";
const char* str = hipGetErrorString(code);
char err[1024] = {{0}};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// printf("_launch hip kernel\\n");
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
if (gridX*gridY*gridZ > 0) {{
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
typedef struct _DevicePtrInfo {{
hipDeviceptr_t dev_ptr;
bool valid;
}} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
if(!ptr_info.dev_ptr)
return ptr_info;
uint64_t dev_ptr;
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == hipErrorInvalidValue) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
Py_DECREF(ret);
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return ptr_info;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
// printf("launch\\n");
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int num_ctas;
int clusterDimX;
int clusterDimY;
int clusterDimZ;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject_CallObject(launch_enter_hook, args);
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
if(PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
else:
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
src = f"""
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
src = f"""
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>

View File

@@ -1,17 +1,18 @@
import functools
import os
from ..common.build import is_hip
from . import core
@functools.lru_cache()
def libdevice_path():
import torch
third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party")
if torch.version.hip is None:
default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc")
if is_hip():
default = os.path.join(third_party_dir, "hip", "lib", "bitcode", "cuda2gcn.bc")
else:
default = os.path.join(third_party_dir, "rocm", "lib", "bitcode", "cuda2gcn.bc")
default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc")
return os.getenv("TRITON_LIBDEVICE_PATH", default)

View File

@@ -4,6 +4,7 @@ from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar
from .._C.libtriton.triton import ir
from ..common.build import is_hip
from . import core as tl
import triton._C.libtriton.triton as _triton
@@ -1301,6 +1302,19 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
return False
return True
def gpu_has_mfma() -> bool:
if not is_hip():
return False
return True # mfma supported in ['gfx908', 'gfx90a']
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
if not gpu_has_mfma():
return False
# TODO: Add check for configurations and types.
return True
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,

View File

@@ -383,20 +383,20 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
device_type = self._conclude_device_type(device_types, {pinned_memory_flags})
device_backend = None
if device_type not in ['cuda', 'hip']:
if device_type not in ['cuda']:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError('Cannot find backend for ' + device_type)
if device is None:
if device_type in ['cuda', 'hip']:
if device_type in ['cuda']:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ['cuda', 'hip']:
if device_type in ['cuda']:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()