[ROCM] Core Functionality for AMD (#1983)

* this pr adds a third party backend for triton that works on AMD
* this expose a lot of the work that has been done in our
[fork](https://github.com/ROCmSoftwarePlatform/triton)
* most unit tests on `test_core.py` pass
* it skips some unit tests for various reasons
* we plan to follow up with more prs improving Functionality and
Performance in the future

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Michael Melesse
2023-10-16 15:06:07 -05:00
parent 833c9b985f
commit 09ba348f87
17 changed files with 264 additions and 377 deletions

View File

@@ -27,7 +27,7 @@ jobs:
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]'
echo '::set-output name=matrix-optional::[]'
echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]'
else
echo '::set-output name=matrix-required::["ubuntu-latest"]'
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
@@ -210,10 +210,12 @@ jobs:
- name: Install Triton on ROCM
if: ${{ env.BACKEND == 'ROCM'}}
run: |
git submodule update --init --recursive
cd python
python3 -m pip install --upgrade pip
python3 -m pip install cmake==3.24
python3 -m pip install torch==1.13.1 --index-url https://download.pytorch.org/whl/rocm5.2
export TRITON_CODEGEN_AMD_HIP_BACKEND=1
python3 -m pip install --no-build-isolation -vvv '.[tests]'
- name: Install Triton on XPU
@@ -235,7 +237,7 @@ jobs:
if: ${{ env.BACKEND == 'ROCM'}}
run: |
cd python/test/unit/language
python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py::test_empty_kernel"
python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py"
- name: Run python tests on XPU
if: ${{ env.BACKEND == 'XPU'}}

4
.gitmodules vendored
View File

@@ -1,3 +1,7 @@
[submodule "third_party/intel_xpu_backend"]
path = third_party/intel_xpu_backend
url = http://github.com/intel/intel-xpu-backend-for-triton
[submodule "third_party/amd_hip_backend"]
path = third_party/amd_hip_backend
url = https://github.com/ROCmSoftwarePlatform/triton
branch = third_party_backend_2

View File

@@ -249,7 +249,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO
${dialect_libs}
${conversion_libs}

View File

@@ -53,7 +53,6 @@ llvm_update_compile_flags(triton-translate)
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO
${dialect_libs}
${conversion_libs}
# tests

View File

@@ -15,7 +15,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "llvm/IR/LLVMContext.h"
@@ -138,16 +137,11 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::errs() << "Translate to LLVM IR failed";
}
if (targetKind == "llvmir")
if (targetKind == "llvmir") {
llvm::outs() << *llvmir << '\n';
else if (targetKind == "ptx")
} else if (targetKind == "ptx") {
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
ptxVersion.getValue());
else if (targetKind == "hsaco") {
auto [module, hsaco] = ::triton::translateLLVMIRToHSACO(
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
GCNFeatures.getValue());
llvm::outs() << hsaco;
} else {
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
return failure();

View File

@@ -46,7 +46,7 @@ inline std::string getenv(const char *name) {
inline bool getBoolEnv(const std::string &env) {
std::string msg = "Environment variable " + env + " is not recognized";
assert(triton::ENV_VARS.find(env.c_str()) != triton::ENV_VARS.end() &&
assert(::triton::ENV_VARS.find(env.c_str()) != ::triton::ENV_VARS.end() &&
msg.c_str());
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");

View File

@@ -1,3 +1,2 @@
add_subdirectory(LLVMIR)
add_subdirectory(PTX)
add_subdirectory(HSACO)

View File

@@ -1,9 +0,0 @@
add_mlir_translation_library(TritonHSACO
HSACOTranslation.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
TritonLLVMIR
)

View File

@@ -33,7 +33,6 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
@@ -2083,24 +2082,6 @@ void init_triton_translation(py::module &m) {
const std::vector<std::string> &paths) {
::mlir::triton::addExternalLibs(op, names, paths);
});
m.def(
"translate_llvmir_to_hsaco",
[](const std::string llvmIR, std::string gfx_arch, std::string gfx_triple,
std::string gfx_features) -> std::tuple<std::string, std::string> {
// create LLVM module from C++
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
// translate module to HSACO
auto hsacoCode = triton::translateLLVMIRToHSACO(
*module, gfx_arch, gfx_triple, gfx_features);
return hsacoCode;
},
ret::take_ownership);
}
void init_triton(py::module &m) {

View File

@@ -12,6 +12,7 @@ from numpy.random import RandomState
import triton
import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.common.build import is_hip
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
int_dtypes = ['int8', 'int16', 'int32', 'int64']
@@ -25,6 +26,13 @@ torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1]
num_ctas_list = [1]
if is_hip():
GPU_DIALECT = "triton_gpu_rocm"
THREADS_PER_WARP = 64
else:
GPU_DIALECT = "triton_gpu"
THREADS_PER_WARP = 32
def _bitwidth(dtype: str) -> int:
# ex.: "int64" -> 64
@@ -137,7 +145,7 @@ class MmaLayout:
self.instr_shape = str(instr_shape)
def __str__(self):
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>"
return f"#{GPU_DIALECT}.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>"
class BlockedLayout:
@@ -151,7 +159,7 @@ class BlockedLayout:
self.cta_order = str(cta_order)
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
class SharedLayout:
@@ -165,7 +173,7 @@ class SharedLayout:
self.cta_order = str(cta_order)
def __str__(self):
return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
@@ -851,6 +859,8 @@ def test_abs(dtype_x, device):
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5])
def test_abs_fp8(in_dtype, device):
if is_hip():
pytest.skip('test_abs_fp8 not supported on HIP.')
@triton.jit
def abs_kernel(X, Z, SIZE: tl.constexpr):
@@ -1056,6 +1066,9 @@ def noinline_multi_values_fn(x, y, Z):
@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"])
def test_noinline(mode, device):
if is_hip() and mode == "shared":
pytest.skip('test_noinline["shared"] not supported on HIP.')
@triton.jit
def kernel(X, Y, Z):
x = tl.load(X)
@@ -1141,6 +1154,9 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
sem_str = "acq_rel" if sem is None else sem
if is_hip():
return
assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
@@ -1232,6 +1248,8 @@ def test_atomic_cas(sem, num_ctas, device):
h = serialized_add[(64,)](data, Lock, SEM=sem, num_ctas=num_ctas)
sem_str = "acq_rel" if sem is None else sem
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
if is_hip():
return
assert f"atom.global.{sem_str}" in h.asm["ptx"]
@@ -1261,6 +1279,9 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device):
check_type_supported(dtype_x, device)
check_type_supported(dtype_z, device)
if is_hip() and (dtype_z == "bfloat16"):
pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')
size = 1024
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
if dtype_x.startswith('bfloat'):
@@ -1358,7 +1379,10 @@ def test_load_store_same_ptr(device):
for _ in range(1000):
x = torch.ones((65536,), device=device, dtype=torch.float32)
kernel[(65536,)](x, num_warps=32)
if is_hip():
kernel[(65536,)](x, num_warps=16) # threads per Warp for ROCM is 64
else:
kernel[(65536,)](x, num_warps=32)
assert torch.all(x == 2)
@@ -1452,6 +1476,8 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
"""
check_type_supported(in_dtype, device)
check_type_supported(out_dtype, device)
if is_hip():
pytest.skip('test_abs_fp8 not supported on HIP.')
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
@@ -1507,6 +1533,9 @@ def get_reduced_dtype(dtype_str, op):
def test_reduce1d(op, dtype_str, shape, num_ctas, device):
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
if is_hip():
pytest.skip(f"test_reduce1d not supported on HIP")
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK: tl.constexpr):
@@ -1597,7 +1626,10 @@ reduce_configs2 = [
def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
if is_hip():
pytest.skip(f"test_reduce2d not supported on HIP")
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
@@ -1667,6 +1699,8 @@ scan_configs = [
@pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs)
def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
if is_hip():
pytest.skip("test_scan2d is not supported in HIP")
check_type_supported(dtype_str, device)
# triton kernel
@@ -1720,6 +1754,9 @@ scan_layouts = [
@pytest.mark.parametrize("src_layout", scan_layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_scan_layouts(M, N, src_layout, axis, device):
if is_hip():
pytest.skip("test_scan_layouts is not supported in HIP")
ir = f"""
#blocked = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
@@ -1783,6 +1820,9 @@ layouts = [
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_reduce_layouts(M, N, src_layout, axis, device):
if is_hip():
pytest.skip("test_reduce_layouts is not supported in HIP")
rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
rdims_1d = f"{N}" if axis == 0 else f"{M}"
store_range = "%7" if axis == 0 else "%1"
@@ -1792,28 +1832,28 @@ def test_reduce_layouts(M, N, src_layout, axis, device):
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked>
%3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
%4 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
%7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%11 = tt.splat %arg2 : (!tt.ptr<i32>) -> tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
%14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src>
%14 = {GPU_DIALECT}.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src>
%15 = "tt.reduce"(%14) ({{
^bb0(%arg3: i32, %arg4: i32):
%17 = arith.addi %arg3, %arg4 : i32
tt.reduce.return %17 : i32
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
%18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked>
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
%18 = {GPU_DIALECT}.convert_layout %15 : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked>
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xi32, #blocked>
tt.return
}}
@@ -1854,17 +1894,20 @@ layouts = [
@pytest.mark.parametrize("M", [32, 64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
def test_store_op(M, src_layout, device):
if is_hip():
pytest.skip("test_convert1d is not supported yet in HIP")
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
tt.store %8, %4 : tensor<{M}x1xf32, #src>
@@ -1903,20 +1946,23 @@ layouts = [
@pytest.mark.parametrize("src_dim", [0, 1])
@pytest.mark.parametrize("dst_dim", [0, 1])
def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
if is_hip():
pytest.skip("test_convert1d is not supported in HIP")
ir = f"""
#dst = {dst_layout}
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%7 = triton_gpu.convert_layout %3 : (tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.store %6, %7 : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.return
}}
}}
@@ -1962,6 +2008,9 @@ layouts = [
@pytest.mark.parametrize("op", ["sum", "max"])
@pytest.mark.parametrize("first_axis", [0, 1])
def test_chain_reduce(M, N, src_layout, op, device, first_axis):
if is_hip():
pytest.skip("test_chain_reduce is not supported in HIP")
op_str = ""
if op == "sum":
op_str = f"""
@@ -1969,19 +2018,19 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
tt.reduce.return %13 : i32"""
elif op == "max":
op_str = f"""
%13 = "triton_gpu.cmpi"(%arg2, %arg3) <{{predicate = 4 : i64}}> : (i32, i32) -> i1
%13 = "{GPU_DIALECT}.cmpi"(%arg2, %arg3) <{{predicate = 4 : i64}}> : (i32, i32) -> i1
%14 = arith.select %13, %arg2, %arg3 : i32
tt.reduce.return %14 : i32"""
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
@@ -1991,11 +2040,11 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
{op_str}
}}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #triton_gpu.slice<{{dim = {first_axis}, parent = #src}}>>
}}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>
%12 = "tt.reduce"(%11) ({{
^bb0(%arg2: i32, %arg3: i32):
{op_str}
}}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #triton_gpu.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32
}}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
tt.return
}}
@@ -2063,6 +2112,8 @@ def test_generic_reduction(device):
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_permute(dtype_str, shape, perm, num_ctas, device):
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
if is_hip():
pytest.skip(f"test_permute is not supported in HIP")
# triton kernel
@triton.jit
@@ -2099,6 +2150,10 @@ def test_permute(dtype_str, shape, perm, num_ctas, device):
# compare
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
if is_hip():
return
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
@@ -2115,7 +2170,7 @@ def test_permute(dtype_str, shape, perm, num_ctas, device):
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
for shape in [(64, 64, 64), (16, 16, 16)]
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for in_dtype, out_dtype in [('float16', 'float16'),
@@ -2146,6 +2201,17 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
check_cuda_only(device)
capability = torch.cuda.get_device_capability()
if is_hip():
# set capability to large number to jump over check below
# check are not relevant to amd gpu, left them for smaller diff between test_core.py and test_core_amd.py tests
capability = (100, 100)
if out_dtype is None:
if in_dtype in float_dtypes:
out_dtype = "float32"
else:
out_dtype = "int32"
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if capability[0] < 8:
@@ -2160,6 +2226,16 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
# TODO: support out_dtype=float16 for tl.dot on V100
pytest.skip("Only test out_dtype=float16 on devices with sm >=80")
if is_hip():
if (M, N, K) in [(64, 128, 128)]:
pytest.skip(f"test_dot{(M, N, K)} not supported on HIP: memory out of resource.")
if (M, N, K, num_warps) in [(128, 256, 32, 8), (128, 128, 64, 4)]:
pytest.skip(f"test_dot{(M, N, K)} not supported on HIP. Reduce Warp to work")
if M == 16 or N == 16 or K == 16:
pytest.skip(f"test_dot{(M, N, K)} segfaults on HIP")
if epilogue == "softmax":
pytest.skip(f"test_dot{epilogue} segfaults on HIP")
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
if num_ctas > 1 and in_dtype == 'int8':
@@ -2247,6 +2323,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
out_dtype = tl.float16
else:
out_dtype = tl.float32
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
@@ -2261,20 +2338,24 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
ALLOW_TF32=allow_tf32,
num_warps=num_warps, num_ctas=num_ctas,
out_dtype=out_dtype)
if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
ptx = pgm.asm["ptx"]
start = ptx.find("shfl.sync")
end = ptx.find("cvt.rn.f16.f32")
red_code = ptx[start:end]
assert len(red_code) > 0
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
# skip this check on hopper because there are some functions whose name contain "shared" in ptx.
# TODO: we should eliminate these unused functions in ptx code.
if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]):
assert "shared" not in red_code
assert "bar.sync" not in red_code
if is_hip():
pass
else:
ptx = pgm.asm["ptx"]
start = ptx.find("shfl.sync")
end = ptx.find("cvt.rn.f16.f32")
red_code = ptx[start:end]
assert len(red_code) > 0
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
# skip this check on hopper because there are some functions whose name contain "shared" in ptx.
# TODO: we should eliminate these unused functions in ptx code.
if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]):
assert "shared" not in red_code
assert "bar.sync" not in red_code
# torch result
if in_dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32),
@@ -2300,9 +2381,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
elif out_dtype == tl.float16:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# added atol, to loose precision for float16xfloat16->float32 case
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
if is_hip():
return
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
@@ -2366,6 +2450,9 @@ def test_dot_mulbroadcastred(in_dtype, device):
h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK)
z_ref = np.matmul(x, y)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01)
if is_hip():
return
assert "tt.dot" in h.asm['ttir']
# with option ENABLE_MMA_V3 on, we will not pipeline the load op for Y
# as the loaded value is in rowmajor. But MMAv3 requires it's second
@@ -2432,6 +2519,9 @@ def test_dot_without_load(dtype_str, device):
capability = torch.cuda.get_device_capability()
allow_tf32 = capability[0] > 7
if is_hip() and dtype_str == "float16":
pytest.skip("test_dot_without_load[float16] not supported in HIP")
@triton.jit
def _kernel(out, ALLOW_TF32: tl.constexpr):
a = GENERATE_TEST_HERE
@@ -2512,6 +2602,9 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device):
# FIXME: Shape too small for ldmatrix when num_ctas=4
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device):
if is_hip():
pytest.skip("test_masked_load_shared_memory is not supported in HIP")
check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested
M = 32
@@ -2571,6 +2664,9 @@ def test_load_cache_modifier(cache, device):
tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
if is_hip():
return
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
@@ -2597,6 +2693,10 @@ def test_vectorization(N, num_ctas, device):
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](
dst, src, N=N, BLOCK_SIZE=block_size)
if is_hip():
return
ptx = pgm.asm["ptx"]
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
@@ -2620,6 +2720,9 @@ def test_vectorization_hints(has_hints, device):
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints)
if is_hip():
return
ptx = pgm.asm["ptx"]
if has_hints:
assert "ld.global.v4.b32" in ptx
@@ -2642,6 +2745,8 @@ def test_store_cache_modifier(cache):
x = tl.load(src + offsets)
tl.store(dst + offsets, x, cache_modifier=CACHE)
if is_hip():
return
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
@@ -2793,6 +2898,9 @@ def test_value_specialization_overflow(value: int, overflow: bool, device) -> No
@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device):
if is_hip():
if (is_rhs_constexpr, is_lhs_constexpr, op) in [(False, False, "<<"), (False, False, ">>"), (False, True, "<<")]:
pytest.skip(f"test_bin_op_constexpr[{is_lhs_constexpr}-{is_rhs_constexpr}-{op}] is not supported in HIP")
@triton.jit
def kernel(Z, X, Y):
@@ -2968,6 +3076,9 @@ def test_num_warps_pow2(device):
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device):
if is_hip() and expr == "math.scalbn":
pytest.skip("test_math_tensor[math.scalbn] is not supported in HIP")
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -3063,6 +3174,9 @@ def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device):
def test_inline_asm(num_ctas, device):
check_cuda_only(device)
if is_hip():
pytest.skip("test_inline_asm is not supported in HIP")
@triton.jit
def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -3089,6 +3203,9 @@ def test_inline_asm(num_ctas, device):
def test_inline_asm_packed(num_ctas, device):
check_cuda_only(device)
if is_hip():
pytest.skip("test_inline_asm is not supported in HIP")
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -3392,6 +3509,8 @@ def test_while(device):
def test_globaltimer(device):
if is_hip():
pytest.skip("test_globaltimer is not supported in HIP")
check_cuda_only(device)
@triton.jit
@@ -3411,6 +3530,8 @@ def test_globaltimer(device):
def test_smid(device):
if is_hip():
pytest.skip("test_smid is not supported in HIP")
check_cuda_only(device)
@triton.jit
@@ -3456,6 +3577,9 @@ intermediate_layouts = [
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device):
if is_hip():
pytest.skip("test_convert2d is not supported in HIP")
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):

View File

@@ -5,6 +5,7 @@ import importlib.util
import os
import re
import subprocess
import traceback
from typing import Dict
from ..runtime.driver import DriverBase
@@ -94,7 +95,7 @@ def get_backend(device_type: str):
try:
importlib.import_module(device_backend_package_name, package=__spec__.name)
except Exception:
return None
traceback.print_exc()
else:
return None
return _backends[device_type] if device_type in _backends else None

View File

@@ -5,19 +5,18 @@ import hashlib
import json
import os
import re
import subprocess
import tempfile
from collections import namedtuple
from pathlib import Path
from typing import Any, Tuple
from typing import Any
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir, get_arch_info,
get_warp_size)
translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
@@ -214,71 +213,6 @@ def ptx_to_cubin(ptx: str, arch: int):
return compile_ptx_to_cubin(ptx, ptxas, arch)
# AMDGCN translation
def get_amdgcn_bitcode_paths(arch):
gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
"ocml.bc",
"ockl.bc",
"oclc_finite_only_off.bc",
"oclc_daz_opt_off.bc",
"oclc_correctly_rounded_sqrt_on.bc",
"oclc_unsafe_math_off.bc",
"oclc_wavefrontsize64_on.bc",
"oclc_abi_version_400.bc",]
gfx_arch = arch[1]
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")
amdgcn_bitcode_paths = {}
i = 0
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
bc_path = bitcode_path_dir + bc_lib
if os.path.exists(bc_path):
amdgcn_bitcode_paths['library_' + str(i)] = bc_path
i += 1
bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
if os.path.exists(bc_gfx_path):
amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
return amdgcn_bitcode_paths
def get_amdgpu_arch_fulldetails():
"""
get the amdgpu full ISA details for compiling:
i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-
"""
try:
# TODO: package rocm.cc with Triton
arch_info = get_arch_info()
warp_size = get_warp_size()
gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
arch_triple = gfx_arch_details[0]
arch_name_features = gfx_arch_details[1].split(':')
arch_name = arch_name_features[0]
arch_features = ""
return [arch_triple, arch_name, arch_features, warp_size]
except BaseException as e:
print("Error: Attempting to get amgpu ISA Details {}".format(e))
return None
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
'''
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
:param mod: a TritonGPU dialect module
:return:
- AMDGCN code
- Path to HSACO object
'''
return translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
@@ -347,8 +281,10 @@ arg_type_pattern = {
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
if is_hip():
ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:'
else:
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
def _get_jsonable_constants(constants):
@@ -389,17 +325,10 @@ def is_hip():
from ..language.semantic import gpu_matrix_core_version
def get_architecture_descriptor(capability):
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
if capability is None:
if torch.version.hip is None:
device = get_current_device()
capability = get_device_capability(device)
capability = capability[0] * 10 + capability[1]
else:
capability = get_amdgpu_arch_fulldetails()
device = get_current_device()
capability = get_device_capability(device)
capability = capability[0] * 10 + capability[1]
return capability
@@ -429,23 +358,6 @@ def get_arch_default_num_stages(device_type, capability=None):
return num_stages
def add_rocm_stages(arch, extern_libs, stages):
extern_libs.update(get_amdgcn_bitcode_paths(arch))
for key in list(extern_libs):
if extern_libs[key] == '' or extern_libs[key] is None:
extern_libs.pop(key)
gfx_arch_full_details = arch
gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
if gfx_arch is None:
raise RuntimeError('gfx_arch is None (not specified)')
stages["amdgcn"] = (lambda path: Path(path).read_text(),
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
gfx_arch_full_details[2]))
def add_cuda_stages(arch, extern_libs, stages):
stages["ptx"] = (lambda path: Path(path).read_text(),
@@ -457,23 +369,22 @@ def add_cuda_stages(arch, extern_libs, stages):
def compile(fn, **kwargs):
# Get device type to decide which backend should be used
device_type = kwargs.get("device_type", "cuda")
_device_backend = get_backend(device_type)
capability = kwargs.get("cc", None)
if device_type in ["cuda", "hip"]:
# hip with kwargs.get("cc", None) causes multiprocessing issues in torch.compile
if device_type == "hip":
arch = get_architecture_descriptor(None if type(capability) is int else capability)
else:
arch = get_architecture_descriptor(capability)
if is_hip():
device_type = "hip"
if device_type == "cuda":
_device_backend = get_backend(device_type)
arch = get_architecture_descriptor(capability)
else:
_device_backend = get_backend(device_type)
assert _device_backend
arch = _device_backend.get_architecture_descriptor(**kwargs)
is_cuda = device_type == "cuda" and _is_cuda(arch)
is_hip = device_type in ["cuda", "hip"] and not is_cuda
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
if is_hip():
is_cuda = False
context = ir.context()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
@@ -506,14 +417,20 @@ def compile(fn, **kwargs):
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
add_cuda_stages(arch, extern_libs, stages)
elif is_hip:
add_rocm_stages(arch, extern_libs, stages)
elif device_type == "hip":
_device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages)
elif device_type == "xpu":
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
_device_backend.add_stages(arch, extern_libs, stages)
else:
# pass the user's configuration to the backend device.
arch["num_warps"] = num_warps
@@ -632,17 +549,23 @@ def compile(fn, **kwargs):
else:
asm[ir_name] = str(next_module)
if ir_name == "llir" and "shared" not in metadata:
metadata["shared"] = get_shared_memory_size(module)
if is_hip():
metadata["shared"] = _device_backend.get_shared_memory_size(module)
else:
metadata["shared"] = get_shared_memory_size(module)
if ir_name == "ttgir":
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
if metadata["enable_warp_specialization"]:
metadata["num_warps"] = get_num_warps(next_module)
if is_hip():
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
else:
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
if metadata["enable_warp_specialization"]:
metadata["num_warps"] = get_num_warps(next_module)
if ir_name == "ptx":
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
if ir_name == "amdgcn":
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
asm["hsaco_path"] = next_module[1]
if not is_cuda and not is_hip:
if not is_cuda and not is_hip():
_device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
module = next_module
@@ -667,7 +590,7 @@ def compile(fn, **kwargs):
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
# cache manager
if is_cuda or is_hip:
if is_cuda:
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
else:
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
@@ -707,7 +630,7 @@ class CompiledKernel:
self.tensormaps_info = metadata["tensormaps_info"]
self.constants = metadata["constants"]
self.device_type = metadata["device_type"]
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda"] else None
# initialize asm dict
self.asm = asm
# binaries are lazily initialized
@@ -721,7 +644,7 @@ class CompiledKernel:
if self.cu_module is not None:
return
if self.device_type in ["cuda", "hip"]:
if self.device_type in ["cuda"]:
device = get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
@@ -767,7 +690,7 @@ class CompiledKernel:
def runner(*args, stream=None):
args_expand = self.assemble_tensormap_to_arg(args)
if stream is None:
if self.device_type in ["cuda", "hip"]:
if self.device_type in ["cuda"]:
stream = get_cuda_stream()
else:
stream = get_backend(self.device_type).get_stream(None)

View File

@@ -3,16 +3,11 @@ import os
import tempfile
from ..common import _build
from ..common.build import is_hip
from ..runtime.cache import get_cache_manager
from ..runtime.jit import version_key
from .utils import generate_cu_signature
def is_hip():
import torch
return torch.version.hip is not None
# ----- stub --------
@@ -103,150 +98,9 @@ def generate_launcher(constants, signature, ids):
format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
if is_hip():
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
#include <stdbool.h>
#include <dlfcn.h>
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [HIP]: ";
const char* str = hipGetErrorString(code);
char err[1024] = {{0}};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// printf("_launch hip kernel\\n");
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
if (gridX*gridY*gridZ > 0) {{
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
typedef struct _DevicePtrInfo {{
hipDeviceptr_t dev_ptr;
bool valid;
}} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
if(!ptr_info.dev_ptr)
return ptr_info;
uint64_t dev_ptr;
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == hipErrorInvalidValue) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
Py_DECREF(ret);
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return ptr_info;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
// printf("launch\\n");
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int num_ctas;
int clusterDimX;
int clusterDimY;
int clusterDimZ;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject_CallObject(launch_enter_hook, args);
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
if(PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
else:
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
src = f"""
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
src = f"""
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>

View File

@@ -1,17 +1,18 @@
import functools
import os
from ..common.build import is_hip
from . import core
@functools.lru_cache()
def libdevice_path():
import torch
third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party")
if torch.version.hip is None:
default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc")
if is_hip():
default = os.path.join(third_party_dir, "hip", "lib", "bitcode", "cuda2gcn.bc")
else:
default = os.path.join(third_party_dir, "rocm", "lib", "bitcode", "cuda2gcn.bc")
default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc")
return os.getenv("TRITON_LIBDEVICE_PATH", default)

View File

@@ -4,6 +4,7 @@ from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar
from .._C.libtriton.triton import ir
from ..common.build import is_hip
from . import core as tl
import triton._C.libtriton.triton as _triton
@@ -1301,6 +1302,19 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
return False
return True
def gpu_has_mfma() -> bool:
if not is_hip():
return False
return True # mfma supported in ['gfx908', 'gfx90a']
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
if not gpu_has_mfma():
return False
# TODO: Add check for configurations and types.
return True
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,

View File

@@ -383,20 +383,20 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
device_type = self._conclude_device_type(device_types, {pinned_memory_flags})
device_backend = None
if device_type not in ['cuda', 'hip']:
if device_type not in ['cuda']:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError('Cannot find backend for ' + device_type)
if device is None:
if device_type in ['cuda', 'hip']:
if device_type in ['cuda']:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ['cuda', 'hip']:
if device_type in ['cuda']:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()

1
third_party/amd_hip_backend vendored Submodule