[FRONTEND] use enum instead of bool to select target (#2118)

Before this PR, whether `TritonGPUToLLVMIRPass` generated
NVVM-compatible LLVM or ROCDL-compatible LLVM was controlled by a boolean
flag, `isROCM`. This approach does not scale to additional backends.
This PR changes it to use an enum instead, so that new targets can be added
easily when needed.

---------

Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Whitney Tsang
2023-08-17 21:37:09 -04:00
committed by GitHub
parent b33f97a682
commit 100cabd0e4
10 changed files with 67 additions and 43 deletions

View File

@@ -81,6 +81,11 @@ void init_triton_runtime(py::module &&m) {
.value("CUDA", CUDA)
.value("ROCM", ROCM)
.export_values();
py::enum_<mlir::triton::Target>(m, "TARGET")
.value("NVVM", mlir::triton::NVVM)
.value("ROCDL", mlir::triton::ROCDL)
.export_values();
}
// A custom op builder that keeps track of the last location
@@ -1804,11 +1809,12 @@ void init_triton_translation(py::module &m) {
m.def(
"translate_triton_gpu_to_llvmir",
[](mlir::ModuleOp op, int computeCapability,
mlir::triton::gpu::TMAMetadataTy &tmaInfos, bool isROCM) {
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
mlir::triton::Target target) {
py::gil_scoped_release allow_threads;
llvm::LLVMContext llvmContext;
auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
&llvmContext, op, computeCapability, tmaInfos, isROCM);
&llvmContext, op, computeCapability, tmaInfos, target);
if (!llvmModule)
llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

View File

@@ -13,7 +13,7 @@ from typing import Any, Tuple
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
get_shared_memory_size, ir,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
@@ -142,9 +142,9 @@ def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
_add_external_libs(mod, extern_libs)
# TODO: separate tritongpu_to_llvmir for different backends
if _is_cuda(arch):
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, False)
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
else:
return translate_triton_gpu_to_llvmir(mod, 0, True)
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)
# PTX translation