mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] use enum instead of bool to select target (#2118)
Before this PR, whether `TritonGPUToLLVMIRPass` generated NVVM-compatible LLVM or ROCDL-compatible LLVM was controlled by a boolean `isROCM`. This approach is hard to scale. This PR changes it to use an enum instead, so that new targets can be added easily when needed. --------- Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com> Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -81,6 +81,11 @@ void init_triton_runtime(py::module &&m) {
       .value("CUDA", CUDA)
       .value("ROCM", ROCM)
       .export_values();
+
+  py::enum_<mlir::triton::Target>(m, "TARGET")
+      .value("NVVM", mlir::triton::NVVM)
+      .value("ROCDL", mlir::triton::ROCDL)
+      .export_values();
 }

 // A custom op builder that keeps track of the last location
@@ -1804,11 +1809,12 @@ void init_triton_translation(py::module &m) {
   m.def(
       "translate_triton_gpu_to_llvmir",
       [](mlir::ModuleOp op, int computeCapability,
-         mlir::triton::gpu::TMAMetadataTy &tmaInfos, bool isROCM) {
+         mlir::triton::gpu::TMAMetadataTy &tmaInfos,
+         mlir::triton::Target target) {
         py::gil_scoped_release allow_threads;
         llvm::LLVMContext llvmContext;
         auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
-            &llvmContext, op, computeCapability, tmaInfos, isROCM);
+            &llvmContext, op, computeCapability, tmaInfos, target);
         if (!llvmModule)
           llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

@@ -13,7 +13,7 @@ from typing import Any, Tuple

 from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
                                    compile_ptx_to_cubin, get_env_vars, get_num_warps,
-                                   get_shared_memory_size, ir,
+                                   get_shared_memory_size, ir, runtime,
                                    translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
                                    translate_triton_gpu_to_llvmir)
 from ..common.backend import get_backend, path_to_ptxas
@@ -142,9 +142,9 @@ def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
     _add_external_libs(mod, extern_libs)
     # TODO: separate tritongpu_to_llvmir for different backends
     if _is_cuda(arch):
-        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, False)
+        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
     else:
-        return translate_triton_gpu_to_llvmir(mod, 0, True)
+        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)


 # PTX translation
Reference in New Issue
Block a user