mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Works as StandAlone and Backend and also Perf is Good
This is a combination of 4 commits. Works as StandAlone and Backend Works as StandAlone and Backend This is a combination of 13 commits. Works StandAlone and as Backend This is a combination of 7 commits. backend set default dir with flag move bitcode to backend dir copy backend save empty test work in backendmode enable backend mode when copying to upstream clean up fix failure minimize diff add skip function fix bug with corrupted dwarf exp match num_wraps fix multi threaded test issue move bitcode file out of lib move backend to python/triton/third_party/hip move libhsa backend works again restart ci clean upstream location first before copy match scripts fix new error memoize backend stuff fix bug
This commit is contained in:
@@ -82,12 +82,12 @@ class BaseBackend:
|
||||
|
||||
_backends: Dict[str, BaseBackend] = {}
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def register_backend(device_type: str, backend_cls: type):
|
||||
if device_type not in _backends:
|
||||
_backends[device_type] = backend_cls.create_backend(device_type)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_backend(device_type: str):
|
||||
if device_type not in _backends:
|
||||
device_backend_package_name = f"...third_party.{device_type}"
|
||||
|
||||
@@ -282,7 +282,7 @@ arg_type_pattern = {
|
||||
"ptx": ptx_arg_type_pattern,
|
||||
}
|
||||
if is_hip():
|
||||
ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:'
|
||||
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
else:
|
||||
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
|
||||
@@ -324,16 +324,23 @@ def is_hip():
|
||||
|
||||
from ..language.semantic import gpu_matrix_core_version
|
||||
|
||||
@functools.lru_cache
|
||||
def get_architecture_descriptor(capability):
|
||||
if capability is None:
|
||||
device = get_current_device()
|
||||
capability = get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
return capability
|
||||
|
||||
if is_hip():
|
||||
_device_backend = get_backend("hip")
|
||||
assert _device_backend
|
||||
arch = _device_backend.get_architecture_descriptor()
|
||||
return arch
|
||||
else:
|
||||
if capability is None:
|
||||
device = get_current_device()
|
||||
capability = get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
return capability
|
||||
|
||||
@functools.lru_cache
|
||||
def get_arch_default_num_warps(device_type):
|
||||
if device_type in ["cuda", "hip"]:
|
||||
if device_type in ["cuda"]:
|
||||
num_warps = 4
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
@@ -343,9 +350,9 @@ def get_arch_default_num_warps(device_type):
|
||||
|
||||
return num_warps
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_arch_default_num_stages(device_type, capability=None):
|
||||
if device_type in ["cuda", "hip"]:
|
||||
if device_type in ["cuda"]:
|
||||
arch = get_architecture_descriptor(capability)
|
||||
is_cuda = device_type == "cuda" and _is_cuda(arch)
|
||||
num_stages = 3 if is_cuda and arch >= 75 else 2
|
||||
@@ -359,7 +366,6 @@ def get_arch_default_num_stages(device_type, capability=None):
|
||||
|
||||
|
||||
def add_cuda_stages(arch, extern_libs, stages):
|
||||
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, arch))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(),
|
||||
@@ -373,6 +379,7 @@ def compile(fn, **kwargs):
|
||||
|
||||
if is_hip():
|
||||
device_type = "hip"
|
||||
capability = None
|
||||
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
@@ -385,6 +392,7 @@ def compile(fn, **kwargs):
|
||||
is_cuda = device_type == "cuda" and _is_cuda(arch)
|
||||
if is_hip():
|
||||
is_cuda = False
|
||||
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch["warp_size"]
|
||||
context = ir.context()
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
|
||||
@@ -412,6 +420,7 @@ def compile(fn, **kwargs):
|
||||
cluster_info.clusterDimY = kwargs["clusterDims"][1]
|
||||
cluster_info.clusterDimZ = kwargs["clusterDims"][2]
|
||||
tma_infos = TMAInfos()
|
||||
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
@@ -424,7 +433,23 @@ def compile(fn, **kwargs):
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
add_cuda_stages(arch, extern_libs, stages)
|
||||
elif device_type == "hip":
|
||||
_device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages)
|
||||
# pass the user's configuration to the backend device.
|
||||
arch["num_warps"] = num_warps
|
||||
arch["num_stages"] = num_stages
|
||||
arch["num_ctas"] = num_ctas
|
||||
|
||||
other = {}
|
||||
other["context"] = context
|
||||
other["warp_size"] = warp_size
|
||||
other["cluster_info"] = cluster_info
|
||||
other["enable_warp_specialization"] = enable_warp_specialization
|
||||
other["enable_persistent"] = enable_persistent
|
||||
other["optimize_epilogue"] = optimize_epilogue
|
||||
other["tma_infos"] = tma_infos
|
||||
other["waves_per_eu"] = waves_per_eu
|
||||
other["matrix_instr_nonkdim"] = matrix_instr_nonkdim
|
||||
|
||||
_device_backend.add_stages(arch, extern_libs, stages, other)
|
||||
elif device_type == "xpu":
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
@@ -554,11 +579,11 @@ def compile(fn, **kwargs):
|
||||
else:
|
||||
metadata["shared"] = get_shared_memory_size(module)
|
||||
if ir_name == "ttgir":
|
||||
if is_hip():
|
||||
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
|
||||
else:
|
||||
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
|
||||
if metadata["enable_warp_specialization"]:
|
||||
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
|
||||
if metadata["enable_warp_specialization"]:
|
||||
if is_hip():
|
||||
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
|
||||
else:
|
||||
metadata["num_warps"] = get_num_warps(next_module)
|
||||
if ir_name == "ptx":
|
||||
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
|
||||
|
||||
@@ -1302,19 +1302,6 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def gpu_has_mfma() -> bool:
|
||||
if not is_hip():
|
||||
return False
|
||||
return True # mfma supported in ['gfx908', 'gfx90a']
|
||||
|
||||
|
||||
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
if not gpu_has_mfma():
|
||||
return False
|
||||
# TODO: Add check for configurations and types.
|
||||
return True
|
||||
|
||||
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
|
||||
109
python/triton/third_party/hip/CMakeLists.txt
vendored
Normal file
109
python/triton/third_party/hip/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
# FLAGS
|
||||
message(STATUS "HIP_BACKEND_MODE = ${HIP_BACKEND_MODE}")
|
||||
set(ROCM_DEFAULT_DIR "/opt/rocm")
|
||||
add_definitions( -DROCM_DEFAULT_DIR="${ROCM_DEFAULT_DIR}")
|
||||
set(ROCM_LIBRARIES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lib/hsa/libhsa-runtime64.so
|
||||
)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
|
||||
|
||||
# shows dependecy of targets
|
||||
set_property(GLOBAL PROPERTY GLOBAL_DEPENDS_DEBUG_MODE 1)
|
||||
|
||||
# Python module
|
||||
if(TRITON_BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding HIP Backend Python module")
|
||||
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/rocm_backend_for_triton.cc)
|
||||
include_directories("." ${PYTHON_SRC_PATH})
|
||||
include_directories(../include)
|
||||
|
||||
if(PYTHON_INCLUDE_DIRS)
|
||||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
else()
|
||||
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
link_directories(${Python3_LIBRARY_DIRS})
|
||||
link_libraries(${Python3_LIBRARIES})
|
||||
add_link_options(${Python3_LINK_OPTIONS})
|
||||
endif()
|
||||
|
||||
add_library(rocm_backend_for_triton SHARED ${PYTHON_SRC})
|
||||
|
||||
set(ROCM_EXTENSION_LIBRARIES
|
||||
TritonAnalysisROCM
|
||||
TritonTransforms
|
||||
TritonHSACO
|
||||
|
||||
# ${dialect_libs}
|
||||
TritonGPUROCMIR
|
||||
TritonGPUROCMTransforms
|
||||
|
||||
# ${conversion_libs}
|
||||
TritonToTritonGPUROCM
|
||||
TritonGPUROCMToLLVM
|
||||
|
||||
# tests
|
||||
# TritonTestAnalysis
|
||||
|
||||
# llvm
|
||||
LLVMCore
|
||||
LLVMSupport
|
||||
LLVMOption
|
||||
LLVMCodeGen
|
||||
LLVMAsmParser
|
||||
|
||||
# MLIR core
|
||||
MLIROptLib
|
||||
MLIRIR
|
||||
MLIRLLVMDialect
|
||||
MLIRPass
|
||||
MLIRSupport
|
||||
MLIRTransforms
|
||||
MLIRExecutionEngine
|
||||
MLIRMathToLLVM
|
||||
MLIRTransformUtils
|
||||
MLIRLLVMToLLVMIRTranslation
|
||||
MLIRROCDLToLLVMIRTranslation
|
||||
# MLIRNVVMToLLVMIRTranslation
|
||||
)
|
||||
target_link_libraries(rocm_backend_for_triton PRIVATE ${ROCM_EXTENSION_LIBRARIES})
|
||||
target_link_libraries(rocm_backend_for_triton PRIVATE ${LLVM_LIBRARIES})
|
||||
link_libraries(stdc++fs)
|
||||
target_link_options(rocm_backend_for_triton PRIVATE ${LLVM_LDFLAGS})
|
||||
|
||||
# copy to upstream third_party dir
|
||||
file(REMOVE_RECURSE ${PYTHON_THIRD_PARTY_PATH}/hip)
|
||||
file(INSTALL
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/
|
||||
DESTINATION ${PYTHON_THIRD_PARTY_PATH}/hip)
|
||||
|
||||
# set HIP_BACKEND_MODE to true
|
||||
set(HIP_BACKEND_PY "${PYTHON_THIRD_PARTY_PATH}/hip/hip_backend.py")
|
||||
set(HIP_BACKEND_PY_STAMP "${PYTHON_THIRD_PARTY_PATH}/hip/hip_backend.py.stamp")
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${HIP_BACKEND_PY_STAMP}
|
||||
COMMAND
|
||||
sed -i'' -e 's/HIP_BACKEND_MODE[[:space:]]*=[[:space:]]*False/HIP_BACKEND_MODE = True/' ${HIP_BACKEND_PY}
|
||||
COMMAND
|
||||
touch ${HIP_BACKEND_PY_STAMP}
|
||||
DEPENDS ${HIP_BACKEND_PY}
|
||||
COMMENT "Modifying hip_backend.py to enable HIP_BACKEND_MODE."
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
modify_file
|
||||
ALL
|
||||
DEPENDS ${HIP_BACKEND_PY_STAMP}
|
||||
COMMENT "Checking and applying modifications to hip_backend.py"
|
||||
)
|
||||
endif()
|
||||
5
python/triton/third_party/hip/__init__.py
vendored
Executable file
5
python/triton/third_party/hip/__init__.py
vendored
Executable file
@@ -0,0 +1,5 @@
|
||||
from triton.common.backend import register_backend
|
||||
from .hip_backend import HIPBackend
|
||||
|
||||
# register backend
|
||||
register_backend("hip", HIPBackend)
|
||||
493
python/triton/third_party/hip/hip_backend.py
vendored
Normal file
493
python/triton/third_party/hip/hip_backend.py
vendored
Normal file
@@ -0,0 +1,493 @@
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
|
||||
from triton.common import _build
|
||||
from triton.common.backend import BaseBackend, register_backend
|
||||
from triton.compiler.make_launcher import get_cache_manager, version_key, make_so_cache_key
|
||||
from triton.compiler.utils import generate_cu_signature
|
||||
from triton.runtime import jit
|
||||
from triton.runtime.driver import HIPDriver
|
||||
from triton.compiler.compiler import optimize_ttgir, parse_mlir_module, ttgir_to_llir, ttir_to_ttgir
|
||||
|
||||
HIP_BACKEND_MODE = False
|
||||
|
||||
if HIP_BACKEND_MODE:
|
||||
from ..._C.librocm_backend_for_triton import triton as _triton
|
||||
else:
|
||||
from ..._C.libtriton import triton as _triton
|
||||
|
||||
|
||||
def make_stub(name, signature, constants, ids, **kwargs):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
cache_path = so_cache_manager.get_file(so_name)
|
||||
if cache_path is None:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
src = generate_launcher_hip(constants, signature, ids)
|
||||
src_path = os.path.join(tmpdir, "main.c")
|
||||
with open(src_path, "w") as f:
|
||||
f.write(src)
|
||||
so = _build(name, src_path, tmpdir)
|
||||
with open(so, "rb") as f:
|
||||
return so_cache_manager.put(f.read(), so_name, binary=True)
|
||||
else:
|
||||
return cache_path
|
||||
|
||||
|
||||
def ty_to_cpp(ty):
|
||||
if ty[0] == '*':
|
||||
return "hipDeviceptr_t"
|
||||
return {
|
||||
"i1": "int32_t",
|
||||
"i8": "int8_t",
|
||||
"i16": "int16_t",
|
||||
"i32": "int32_t",
|
||||
"i64": "int64_t",
|
||||
"u32": "uint32_t",
|
||||
"u64": "uint64_t",
|
||||
"fp16": "float",
|
||||
"bf16": "float",
|
||||
"fp32": "float",
|
||||
"f32": "float",
|
||||
"fp64": "double",
|
||||
}[ty]
|
||||
|
||||
|
||||
def generate_launcher_hip(constants, signature, ids):
|
||||
start_desc = len(signature)
|
||||
signature = generate_cu_signature(constants, signature, ids)
|
||||
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||
|
||||
def _extracted_type(ty):
|
||||
if ty[0] == '*':
|
||||
return "PyObject*"
|
||||
return {
|
||||
'i1': 'int32_t',
|
||||
'i32': 'int32_t',
|
||||
'i64': 'int64_t',
|
||||
'u32': 'uint32_t',
|
||||
'u64': 'uint64_t',
|
||||
'fp16': 'float',
|
||||
'bf16': 'float',
|
||||
'fp32': 'float',
|
||||
'f32': 'float',
|
||||
'fp64': 'double',
|
||||
}[ty]
|
||||
|
||||
def format_of(ty):
|
||||
return {
|
||||
"PyObject*": "O",
|
||||
"float": "f",
|
||||
"double": "d",
|
||||
"long": "l",
|
||||
"uint32_t": "I",
|
||||
"int32_t": "i",
|
||||
"uint64_t": "K",
|
||||
"int64_t": "L",
|
||||
}[ty]
|
||||
|
||||
format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||
|
||||
# generate glue code
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
src = f"""
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <Python.h>
|
||||
#include <stdbool.h>
|
||||
#include <dlfcn.h>
|
||||
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {{0}};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
||||
// printf("_launch hip kernel\\n");
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
|
||||
if (gridX*gridY*gridZ > 0) {{
|
||||
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
hipDeviceptr_t dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
if(ptr){{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
||||
if(!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
uint64_t dev_ptr;
|
||||
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == hipErrorInvalidValue) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
||||
Py_DECREF(ret);
|
||||
return ptr_info;
|
||||
}}
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
// printf("launch\\n");
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int num_ctas;
|
||||
int clusterDimX;
|
||||
int clusterDimY;
|
||||
int clusterDimZ;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
|
||||
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
|
||||
if(PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
return src
|
||||
|
||||
|
||||
def get_amdgcn_bitcode_paths(gfx_arch: str):
|
||||
# print("get_amdgcn_bitcode_paths")
|
||||
gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
|
||||
"ocml.bc",
|
||||
"ockl.bc",
|
||||
"oclc_finite_only_off.bc",
|
||||
"oclc_daz_opt_off.bc",
|
||||
"oclc_correctly_rounded_sqrt_on.bc",
|
||||
"oclc_unsafe_math_off.bc",
|
||||
"oclc_wavefrontsize64_on.bc",
|
||||
"oclc_abi_version_400.bc", ]
|
||||
|
||||
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
|
||||
|
||||
gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
|
||||
current_dir = Path(__file__)
|
||||
bitcode_path_dir = os.path.join(current_dir.parent.resolve(), "lib/bitcode/")
|
||||
|
||||
amdgcn_bitcode_paths = {}
|
||||
i = 0
|
||||
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
|
||||
bc_path = bitcode_path_dir + bc_lib
|
||||
if os.path.exists(bc_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_path
|
||||
i += 1
|
||||
bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
|
||||
if os.path.exists(bc_gfx_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
|
||||
|
||||
# print(f"amdgcn_bitcode_paths: {amdgcn_bitcode_paths}")
|
||||
return amdgcn_bitcode_paths
|
||||
|
||||
|
||||
def get_amdgpu_arch_fulldetails():
|
||||
# print("get_amdgpu_arch_fulldetails")
|
||||
"""
|
||||
get the amdgpu fulll ISA details for compiling:
|
||||
i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-
|
||||
"""
|
||||
try:
|
||||
# TODO: package rocm.cc with Triton
|
||||
rocm_path_dir = os.getenv("ROCM_PATH", default="/opt/rocm")
|
||||
rocminfo = subprocess.check_output(rocm_path_dir + '/bin/rocminfo').decode()
|
||||
gfx_arch_details = re.search('amd.*', rocminfo).group(0).strip().split('--')
|
||||
arch_triple = gfx_arch_details[0]
|
||||
arch_name_features = gfx_arch_details[1].split(':')
|
||||
arch_name = arch_name_features[0]
|
||||
arch_features = ""
|
||||
|
||||
if (len(arch_name_features) == 3):
|
||||
arch_features = "+" + re.search('\\w+', arch_name_features[1]).group(0) + ","\
|
||||
"-" + re.search('\\w+', arch_name_features[2]).group(0)
|
||||
|
||||
# overwrite if provided by user
|
||||
gfx_arch = os.environ.get('MI_GPU_ARCH', arch_name)
|
||||
if gfx_arch is None:
|
||||
raise RuntimeError('gfx_arch is None (not specified)')
|
||||
|
||||
return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features}
|
||||
except BaseException:
|
||||
return None
|
||||
|
||||
|
||||
def get_kernel_name(src: str, pattern: str) -> str:
|
||||
# print("get_kernel_name")
|
||||
'''
|
||||
Get kernel name from PTX code.
|
||||
This Kernel name is required when launching the kernel.
|
||||
'''
|
||||
# There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
|
||||
assert src
|
||||
for line in src.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith(pattern):
|
||||
return line.split()[-1]
|
||||
|
||||
|
||||
def get_arch_details(arch: dict):
|
||||
# get arch info
|
||||
gfx_arch = os.environ.get('MI_GPU_ARCH', arch["gfx_arch"])
|
||||
gfx_triple = arch["gfx_triple"]
|
||||
gfx_features = arch["gfx_features"]
|
||||
if gfx_arch is None:
|
||||
raise RuntimeError('gfx_arch is None (not specified)')
|
||||
|
||||
return gfx_arch, gfx_triple, gfx_features
|
||||
|
||||
|
||||
def update_extern_libs(extern_libs: dict, gfx_arch: str):
|
||||
# append extern_libs
|
||||
extern_libs.update(get_amdgcn_bitcode_paths(gfx_arch))
|
||||
for key in list(extern_libs):
|
||||
if extern_libs[key] == '' or extern_libs[key] is None:
|
||||
extern_libs.pop(key)
|
||||
|
||||
# check extern libs
|
||||
if extern_libs:
|
||||
for name, path in extern_libs.items():
|
||||
if len(name) == 0 or len(path) == 0:
|
||||
raise RuntimeWarning(f"extern_lib has empty value, {name}: {path}")
|
||||
|
||||
names = list(extern_libs.keys())
|
||||
paths = list(extern_libs.values())
|
||||
return names, paths
|
||||
|
||||
|
||||
# passes
|
||||
def ttir_to_ttgir_rocm(module: str, compute_capability: int, num_warps: int, num_stages: int):
|
||||
return _triton.translate_ttir_to_ttgir_rocm(module, compute_capability, num_warps, num_stages)
|
||||
|
||||
|
||||
def optimize_ttgir_rocm():
|
||||
pass
|
||||
|
||||
|
||||
def ttgir_to_llir_rocm(module: str, extern_libs: dict, arch: dict):
|
||||
names, paths = update_extern_libs(extern_libs, arch["gfx_arch"])
|
||||
llvmIR = _triton.translate_ttgir_to_llvmir(module, names, paths)
|
||||
return llvmIR
|
||||
|
||||
|
||||
def llir_to_amdgcn_and_hsaco_rocm(module: str, arch: dict):
|
||||
'''
|
||||
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
|
||||
:param mod: a TritonGPU dialect module
|
||||
:return:
|
||||
- AMDGCN code
|
||||
- Path to HSACO object
|
||||
'''
|
||||
return _triton.translate_llvmir_to_hsaco(module, arch["gfx_arch"], arch["gfx_triple"], arch["gfx_features"])
|
||||
|
||||
|
||||
def ttir_to_amdgcn_and_hsaco(module, context, arch, num_warps, num_stages, extern_libs) -> Tuple[str, str]:
|
||||
gfx_arch, gfx_triple, gfx_features = get_arch_details(arch)
|
||||
names, paths = update_extern_libs(extern_libs, gfx_arch)
|
||||
return _triton.translate_triton_ir_to_amdgcn_and_hsaco(str(module), gfx_arch, gfx_triple, gfx_features, num_warps, num_stages, names, paths)
|
||||
|
||||
|
||||
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
|
||||
'''
|
||||
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
|
||||
:param mod: a TritonGPU dialect module
|
||||
:return:
|
||||
- AMDGCN code
|
||||
- Path to HSACO object
|
||||
'''
|
||||
return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
|
||||
|
||||
|
||||
class HIPBackend(BaseBackend):
|
||||
def __init__(self, device_type: str) -> None:
|
||||
super(HIPBackend, self).__init__(device_type)
|
||||
self.driver = HIPDriver()
|
||||
self.stub_so_path = ""
|
||||
|
||||
def is_standalone(self):
|
||||
return not HIP_BACKEND_MODE
|
||||
|
||||
def add_stages(self, arch: dict, extern_libs: dict, stages: dict, other: dict = {}):
|
||||
if self.is_standalone():
|
||||
num_warps = arch["num_warps"]
|
||||
num_ctas = arch["num_ctas"]
|
||||
num_stages = arch["num_stages"]
|
||||
gfx_arch = arch["gfx_arch"]
|
||||
gfx_triple = arch["gfx_triple"]
|
||||
gfx_features = arch["gfx_features"]
|
||||
|
||||
context = other["context"]
|
||||
warp_size = other["warp_size"]
|
||||
cluster_info = other["cluster_info"]
|
||||
enable_warp_specialization = other["enable_warp_specialization"]
|
||||
enable_persistent = other["enable_persistent"]
|
||||
optimize_epilogue = other["optimize_epilogue"]
|
||||
tma_infos = other["tma_infos"]
|
||||
waves_per_eu = other["waves_per_eu"]
|
||||
matrix_instr_nonkdim = other["matrix_instr_nonkdim"]
|
||||
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
|
||||
|
||||
extern_libs.update(get_amdgcn_bitcode_paths(gfx_arch))
|
||||
for key in list(extern_libs):
|
||||
if extern_libs[key] == '' or extern_libs[key] is None:
|
||||
extern_libs.pop(key)
|
||||
|
||||
stages["amdgcn"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
|
||||
gfx_triple,
|
||||
gfx_features))
|
||||
else:
|
||||
# add stages
|
||||
stages["ttgir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttir_to_ttgir_rocm(str(src), 0, arch["num_warps"], arch["num_stages"]))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir_rocm(src, extern_libs, arch))
|
||||
stages["amdgcn"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_amdgcn_and_hsaco_rocm(src, arch))
|
||||
|
||||
def add_meta_info(self, ir, module, next_module, metadata, asm):
|
||||
pass
|
||||
|
||||
def get_driver(self):
|
||||
return self.driver
|
||||
|
||||
def get_stream(self, idx=None):
|
||||
return jit.get_cuda_stream()
|
||||
|
||||
def get_device_properties(self, device):
|
||||
return self.driver.utils.get_device_properties(device)
|
||||
|
||||
def get_current_device(self):
|
||||
return jit.get_current_device()
|
||||
|
||||
def set_current_device(self, device):
|
||||
return jit.set_current_device(device)
|
||||
|
||||
def get_load_binary_fn(self):
|
||||
return self.driver.utils.load_binary
|
||||
|
||||
def get_kernel_bin(self):
|
||||
return "hsaco_path"
|
||||
|
||||
def get_architecture_descriptor(self, **kwargs):
|
||||
# get arch
|
||||
arch = get_amdgpu_arch_fulldetails()
|
||||
|
||||
# set default values
|
||||
arch["num_warps"] = 4
|
||||
arch["num_stages"] = 2
|
||||
arch["num_ctas"] = 1
|
||||
arch["warp_size"] = 64
|
||||
return arch
|
||||
|
||||
def make_launcher_stub(self, name, signature, constants, ids):
|
||||
# print("HIPBackend.make_launcher_stub")
|
||||
self.stub_so_path = make_stub(name, signature, constants, ids)
|
||||
return self.stub_so_path
|
||||
|
||||
def get_shared_memory_size(self, module):
|
||||
if self.is_standalone():
|
||||
return _triton.get_shared_memory_size(module)
|
||||
else:
|
||||
return _triton.get_shared_memory_size(module)
|
||||
|
||||
def get_num_warps(self, module):
|
||||
if self.is_standalone():
|
||||
return _triton.get_num_warps(module)
|
||||
else:
|
||||
return _triton.get_num_warps(module)
|
||||
1
python/triton/third_party/hip/include/CMakeLists.txt
vendored
Normal file
1
python/triton/third_party/hip/include/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(triton)
|
||||
95
python/triton/third_party/hip/include/triton/AnalysisROCM/Alias.h
vendored
Normal file
95
python/triton/third_party/hip/include/triton/AnalysisROCM/Alias.h
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
#ifndef TRITON_ANALYSISROCM_ALIAS_H
|
||||
#define TRITON_ANALYSISROCM_ALIAS_H
|
||||
|
||||
#include "mlir/Analysis/AliasAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AliasInfo {
|
||||
public:
|
||||
AliasInfo() = default;
|
||||
AliasInfo(Value value) { insert(value); }
|
||||
|
||||
void insert(Value value) { allocs.insert(value); }
|
||||
|
||||
const DenseSet<Value> &getAllocs() const { return allocs; }
|
||||
|
||||
bool operator==(const AliasInfo &other) const {
|
||||
return allocs == other.allocs;
|
||||
}
|
||||
|
||||
/// The pessimistic value state of a value without alias
|
||||
static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) {
|
||||
return AliasInfo();
|
||||
}
|
||||
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
|
||||
|
||||
/// The union of both arguments
|
||||
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
|
||||
|
||||
void print(raw_ostream &os) const {
|
||||
llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); });
|
||||
}
|
||||
|
||||
private:
|
||||
/// The set of allocated values that are aliased by this lattice.
|
||||
/// For now, we only consider aliased value produced by the following
|
||||
/// situations:
|
||||
/// 1. values returned by scf.yield
|
||||
/// 2. block arguments in scf.for
|
||||
/// Example:
|
||||
/// alloc v1 alloc v2
|
||||
/// | |
|
||||
/// |--------------| |------------|
|
||||
/// scf.for v3 scf.for v4 scf.for v5
|
||||
/// |
|
||||
/// scf.yield v6
|
||||
///
|
||||
/// v1's alloc [v1]
|
||||
/// v2's alloc [v2]
|
||||
/// v3's alloc [v1]
|
||||
/// v4's alloc [v1, v2]
|
||||
/// v5's alloc [v2]
|
||||
/// v6's alloc [v1]
|
||||
///
|
||||
/// Therefore, v1's liveness range is the union of v3, v4, and v6
|
||||
/// v2's liveness range is the union of v4 and v5.
|
||||
DenseSet<Value> allocs;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Alias Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class SharedMemoryAliasAnalysis
|
||||
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
|
||||
public:
|
||||
using dataflow::SparseDataFlowAnalysis<
|
||||
dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
|
||||
using dataflow::SparseDataFlowAnalysis<
|
||||
dataflow::Lattice<AliasInfo>>::getLatticeElement;
|
||||
|
||||
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
|
||||
/// Given two values, returns their aliasing behavior.
|
||||
AliasResult alias(Value lhs, Value rhs);
|
||||
|
||||
/// Returns the modify-reference behavior of `op` on `location`.
|
||||
ModRefResult getModRef(Operation *op, Value location);
|
||||
|
||||
void setToEntryState(dataflow::Lattice<AliasInfo> *lattice) override {
|
||||
propagateIfChanged(
|
||||
lattice, lattice->join(
|
||||
AliasInfo::getPessimisticValueState(lattice->getPoint())));
|
||||
}
|
||||
|
||||
/// Computes if the alloc set of the results are changed.
|
||||
void
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
|
||||
ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSISROCM_ALIAS_H
|
||||
264
python/triton/third_party/hip/include/triton/AnalysisROCM/Allocation.h
vendored
Normal file
264
python/triton/third_party/hip/include/triton/AnalysisROCM/Allocation.h
vendored
Normal file
@@ -0,0 +1,264 @@
|
||||
#ifndef TRITON_ANALYSISROCM_ALLOCATION_H
|
||||
#define TRITON_ANALYSISROCM_ALLOCATION_H
|
||||
|
||||
#include "triton/AnalysisROCM/Utility.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace triton {
|
||||
class AllocationAnalysis;
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu_rocm::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
|
||||
/// A class that represents an interval, specified using a start and an end
|
||||
/// values: [Start, End).
|
||||
template <typename T> class Interval {
|
||||
public:
|
||||
Interval() {}
|
||||
Interval(T S) : Start(S), End(S+1) {}
|
||||
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
|
||||
T start() const { return Start; }
|
||||
T end() const { return End; }
|
||||
T size() const { return End - Start; }
|
||||
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
|
||||
bool intersects(const Interval &R) const {
|
||||
return Start < R.End && R.Start < End;
|
||||
}
|
||||
bool operator==(const Interval &R) const {
|
||||
return Start == R.Start && End == R.End;
|
||||
}
|
||||
bool operator!=(const Interval &R) const { return !(*this == R); }
|
||||
bool operator<(const Interval &R) const {
|
||||
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
|
||||
}
|
||||
bool adjacent(T Addr) const {
|
||||
return Addr+1 == Start || Addr == End;
|
||||
}
|
||||
bool adjacent(const Interval &R) const {
|
||||
return adjacent(R.Start) || adjacent(R.End-1);
|
||||
}
|
||||
|
||||
Interval merge(const Interval &R) const {
|
||||
return Interval(std::min(Start, R.Start), std::max(End, R.End));
|
||||
}
|
||||
|
||||
private:
|
||||
T Start = std::numeric_limits<T>::min();
|
||||
T End = std::numeric_limits<T>::max();
|
||||
};
|
||||
|
||||
template <class T> Interval(T, T) -> Interval<T>;
|
||||
|
||||
class Allocation {
|
||||
public:
|
||||
/// A unique identifier for shared memory buffers
|
||||
using BufferId = size_t;
|
||||
using BufferIdSetT = DenseSet<BufferId>;
|
||||
using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;
|
||||
|
||||
static constexpr BufferId InvalidBufferId =
|
||||
std::numeric_limits<BufferId>::max();
|
||||
|
||||
Allocation() = default;
|
||||
/// Creates a new Allocation analysis that computes the shared memory
|
||||
/// information for all associated shared memory values.
|
||||
explicit Allocation(Operation *operation) : operation(operation) {}
|
||||
|
||||
/// Runs allocation analysis on the given top-level operation.
|
||||
void run(FuncAllocMapT &funcAllocMap);
|
||||
|
||||
/// Returns the operation this analysis was constructed from.
|
||||
Operation *getOperation() const { return operation; }
|
||||
|
||||
/// Returns the offset of the given buffer in the shared memory.
|
||||
size_t getOffset(BufferId bufferId) const {
|
||||
return bufferSet.at(bufferId).offset;
|
||||
}
|
||||
|
||||
/// Returns the size of the given buffer in the shared memory.
|
||||
size_t getAllocatedSize(BufferId bufferId) const {
|
||||
return bufferSet.at(bufferId).size;
|
||||
}
|
||||
|
||||
/// Returns the allocated interval of the given buffer.
|
||||
Interval<size_t> getAllocatedInterval(BufferId bufferId) const {
|
||||
auto &buffer = bufferSet.at(bufferId);
|
||||
return Interval<size_t>(buffer.offset, buffer.offset + buffer.size);
|
||||
}
|
||||
|
||||
/// Returns the buffer id of the given value.
|
||||
/// This interface only returns the allocated buffer id.
|
||||
/// If you want to get all the buffer ids that are associated with the given
|
||||
/// value, including alias buffers, use getBufferIds.
|
||||
BufferId getBufferId(Value value) const {
|
||||
if (valueBuffer.count(value)) {
|
||||
return valueBuffer.lookup(value)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns all the buffer ids of the given value, including alias buffers.
|
||||
BufferIdSetT getBufferIds(Value value) const {
|
||||
BufferIdSetT bufferIds;
|
||||
auto allocBufferId = getBufferId(value);
|
||||
if (allocBufferId != InvalidBufferId)
|
||||
bufferIds.insert(allocBufferId);
|
||||
for (auto *buffer : aliasBuffer.lookup(value)) {
|
||||
if (buffer->id != InvalidBufferId)
|
||||
bufferIds.insert(buffer->id);
|
||||
}
|
||||
return bufferIds;
|
||||
}
|
||||
|
||||
/// Returns the scratch buffer id of the given value.
|
||||
BufferId getBufferId(Operation *operation) const {
|
||||
if (opScratch.count(operation)) {
|
||||
return opScratch.lookup(operation)->id;
|
||||
} else if (opVirtual.count(operation)) {
|
||||
return opVirtual.lookup(operation)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns if the given buffer is a virtual buffer.
|
||||
bool isVirtualBuffer(BufferId bufferId) const {
|
||||
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
|
||||
}
|
||||
|
||||
/// Returns the size of total shared memory allocated
|
||||
size_t getSharedMemorySize() const { return sharedMemorySize; }
|
||||
|
||||
private:
|
||||
/// A class that represents a shared memory buffer
|
||||
struct BufferT {
|
||||
/// Explicit: triton_gpu_rocm.alloc_tensor
|
||||
/// Scratch: triton_gpu_rocm.convert_layout
|
||||
/// Virtual: triton.call
|
||||
enum class BufferKind { Explicit, Scratch, Virtual };
|
||||
|
||||
/// MT: thread-safe
|
||||
inline static std::atomic<BufferId> nextId = 0;
|
||||
|
||||
BufferKind kind;
|
||||
BufferId id;
|
||||
size_t size;
|
||||
size_t alignment;
|
||||
size_t offset;
|
||||
|
||||
bool operator==(const BufferT &other) const { return id == other.id; }
|
||||
bool operator<(const BufferT &other) const { return id < other.id; }
|
||||
|
||||
BufferT() : BufferT(BufferKind::Explicit, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
|
||||
size_t offset = 0)
|
||||
: kind(kind), id(nextId++), size(size), alignment(alignment),
|
||||
offset(offset) {}
|
||||
};
|
||||
|
||||
/// Op -> Scratch Buffer
|
||||
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
|
||||
/// Value -> Explicit Buffer
|
||||
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
|
||||
/// Value -> Alias Buffer
|
||||
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
|
||||
/// BufferId -> Buffer
|
||||
using BufferSetT = std::map<BufferId, BufferT>;
|
||||
|
||||
private:
|
||||
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
|
||||
void addBuffer(KeyType &key, Args &&...args) {
|
||||
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
|
||||
bufferSet[buffer.id] = std::move(buffer);
|
||||
if constexpr (Kind == BufferT::BufferKind::Explicit) {
|
||||
valueBuffer[key] = &bufferSet[buffer.id];
|
||||
} else if constexpr (Kind == BufferT::BufferKind::Virtual) {
|
||||
opVirtual[key] = &bufferSet[buffer.id];
|
||||
} else {
|
||||
opScratch[key] = &bufferSet[buffer.id];
|
||||
}
|
||||
}
|
||||
|
||||
void addAlias(Value value, Value alloc) {
|
||||
aliasBuffer[value].insert(valueBuffer[alloc]);
|
||||
}
|
||||
|
||||
private:
|
||||
Operation *operation = nullptr;
|
||||
OpScratchMapT opScratch;
|
||||
OpScratchMapT opVirtual;
|
||||
ValueBufferMapT valueBuffer;
|
||||
AliasBufferMapT aliasBuffer;
|
||||
BufferSetT bufferSet;
|
||||
size_t sharedMemorySize = 0;
|
||||
|
||||
friend class triton::AllocationAnalysis;
|
||||
};
|
||||
|
||||
/// Static analysis that computes the allocation of shared memory buffers
|
||||
/// of the entire call graph.
|
||||
/// The allocation is performed in a post-order walk of the call graph.
|
||||
/// Each call op is treated like convert_layout that allocates a scratch buffer.
|
||||
/// At each call, we compute the start offset of the scratch buffer and pass it
|
||||
/// as an argument to the callee.
|
||||
class ModuleAllocation : public CallGraph<Allocation> {
|
||||
public:
|
||||
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
|
||||
|
||||
explicit ModuleAllocation(ModuleOp moduleOp)
|
||||
: CallGraph<Allocation>(moduleOp) {
|
||||
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
|
||||
// Pre-order edge walk callback
|
||||
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
|
||||
// Post-order node walk callback
|
||||
[&](FunctionOpInterface funcOp) {
|
||||
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
|
||||
if (inserted)
|
||||
iter->second.run(funcMap);
|
||||
});
|
||||
}
|
||||
|
||||
size_t getSharedMemorySize() {
|
||||
size_t size = 0;
|
||||
for (auto funcOp : getRoots()) {
|
||||
auto *alloc = getFuncData(funcOp);
|
||||
size = std::max(size, alloc->getSharedMemorySize());
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
size_t getSharedMemorySize(FunctionOpInterface funcOp) {
|
||||
return getFuncData(funcOp)->getSharedMemorySize();
|
||||
}
|
||||
|
||||
void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) {
|
||||
sharedMemoryValue[funcOp] = value;
|
||||
}
|
||||
|
||||
Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) {
|
||||
return sharedMemoryValue[funcOp];
|
||||
}
|
||||
|
||||
private:
|
||||
FuncOffsetMapT sharedMemoryValue;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSISROCM_ALLOCATION_H
|
||||
358
python/triton/third_party/hip/include/triton/AnalysisROCM/AxisInfo.h
vendored
Normal file
358
python/triton/third_party/hip/include/triton/AnalysisROCM/AxisInfo.h
vendored
Normal file
@@ -0,0 +1,358 @@
|
||||
#ifndef TRITON_ANALYSISROCM_AXISINFO_H
|
||||
#define TRITON_ANALYSISROCM_AXISINFO_H
|
||||
|
||||
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "triton/AnalysisROCM/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef SmallVector<int64_t> DimVectorT;
|
||||
|
||||
public:
|
||||
/// Default constructor
|
||||
AxisInfo() : AxisInfo({}, {}, {}) {}
|
||||
/// Construct contiguity info with known contiguity
|
||||
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
|
||||
DimVectorT knownConstancy)
|
||||
: AxisInfo(knownContiguity, knownDivisibility, knownConstancy, {}) {}
|
||||
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
|
||||
DimVectorT knownConstancy, std::optional<int64_t> knownConstantValue)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), constantValue(knownConstantValue),
|
||||
rank(contiguity.size()) {
|
||||
assert(knownContiguity.size() == static_cast<size_t>(rank));
|
||||
assert(knownDivisibility.size() == static_cast<size_t>(rank));
|
||||
assert(knownConstancy.size() == static_cast<size_t>(rank));
|
||||
}
|
||||
|
||||
/// Accessors
|
||||
int64_t getContiguity(size_t dim) const { return contiguity[dim]; }
|
||||
const DimVectorT &getContiguity() const { return contiguity; }
|
||||
|
||||
int64_t getDivisibility(size_t dim) const { return divisibility[dim]; }
|
||||
const DimVectorT &getDivisibility() const { return divisibility; }
|
||||
|
||||
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
|
||||
const DimVectorT &getConstancy() const { return constancy; }
|
||||
|
||||
int getRank() const { return rank; }
|
||||
|
||||
std::optional<int64_t> getConstantValue() const { return constantValue; }
|
||||
|
||||
template <class T>
|
||||
static void
|
||||
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
|
||||
DimVectorT *divisibility, DimVectorT *constancy);
|
||||
/// Comparison
|
||||
bool operator==(const AxisInfo &other) const {
|
||||
return (contiguity == other.contiguity) &&
|
||||
(divisibility == other.divisibility) &&
|
||||
(constancy == other.constancy) &&
|
||||
(constantValue == other.constantValue) && (rank == other.rank);
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context = nullptr) {
|
||||
return AxisInfo();
|
||||
}
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
|
||||
/// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
|
||||
|
||||
void print(raw_ostream &os) const {
|
||||
auto print = [&](StringRef name, DimVectorT vec) {
|
||||
os << name << " = [";
|
||||
llvm::interleaveComma(vec, os);
|
||||
os << "]";
|
||||
};
|
||||
print("contiguity", contiguity);
|
||||
print(", divisibility", divisibility);
|
||||
print(", constancy", constancy);
|
||||
os << ", constant_value = ";
|
||||
if (constantValue)
|
||||
os << *constantValue;
|
||||
else
|
||||
os << "<none>";
|
||||
}
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of contiguous integers along it.
|
||||
/// Suppose we have an array of N elements,
|
||||
/// with a contiguity value C,
|
||||
/// the array can be divided into a list of
|
||||
/// N/C sequences of C contiguous elements.
|
||||
/// Since we have N = 2^k, C must be a power of two.
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
/// Would have contiguity [1, 4].
|
||||
/// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
/// [18, 22, 26, 30]
|
||||
/// [19, 23, 27, 31]
|
||||
/// Would have contiguity [2, 1].
|
||||
DimVectorT contiguity;
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all groups of
|
||||
// _contiguity_ values along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
// would have divisibility [1, 2]
|
||||
// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
// would have divisibility [4, 1]
|
||||
// On the other hand:
|
||||
// [0, 1, 2, 0, 4, 5, 6, 7]
|
||||
// would have divisibility 1 because
|
||||
// _contiguity_=1
|
||||
DimVectorT divisibility;
|
||||
|
||||
/// The _constancy_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of constant integer along it. This is
|
||||
/// particularly useful to infer the contiguity
|
||||
/// of operations (e.g., add) involving a constant.
|
||||
/// Suppose we have an array of N elements,
|
||||
/// with a constancy value C,
|
||||
/// the array can be divided into a list of
|
||||
/// N/C sequences of C elements with the same value.
|
||||
/// Since we have N = 2^k, C must be a power of two.
|
||||
/// For example
|
||||
/// [8, 8, 8, 8, 12, 12, 12, 12]
|
||||
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||
/// would have constancy [1, 4]
|
||||
DimVectorT constancy;
|
||||
|
||||
/// The constant value of the lattice if we can infer it.
|
||||
std::optional<int64_t> constantValue;
|
||||
|
||||
// number of dimensions of the lattice
|
||||
int rank{};
|
||||
};
|
||||
|
||||
class AxisInfoVisitor {
|
||||
public:
|
||||
AxisInfoVisitor() = default;
|
||||
virtual ~AxisInfoVisitor() = default;
|
||||
|
||||
static bool isContiguousDim(const AxisInfo &info, ArrayRef<int64_t> shape,
|
||||
int dim) {
|
||||
return info.getContiguity(dim) == shape[dim];
|
||||
}
|
||||
|
||||
static bool isConstantDim(const AxisInfo &info, ArrayRef<int64_t> shape,
|
||||
int dim) {
|
||||
return info.getConstancy(dim) == shape[dim];
|
||||
}
|
||||
|
||||
virtual AxisInfo
|
||||
getAxisInfo(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
|
||||
|
||||
virtual bool match(Operation *op) = 0;
|
||||
};
|
||||
|
||||
/// Base class for all operations
|
||||
template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
|
||||
public:
|
||||
using AxisInfoVisitor::AxisInfoVisitor;
|
||||
|
||||
AxisInfo
|
||||
getAxisInfo(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) final {
|
||||
return getAxisInfo(cast<OpTy>(op), operands);
|
||||
}
|
||||
|
||||
bool match(Operation *op) final { return isa<OpTy>(op); }
|
||||
|
||||
virtual AxisInfo
|
||||
getAxisInfo(OpTy op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
|
||||
llvm_unreachable("Unimplemented getAxisInfo");
|
||||
}
|
||||
};
|
||||
|
||||
/// Binary operations
|
||||
template <typename OpTy>
|
||||
class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
auto rank = lhsInfo.getRank();
|
||||
assert(operands.size() == 2 && "Expected two operands");
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
|
||||
for (auto d = 0; d < rank; ++d) {
|
||||
if (constantValue.has_value()) {
|
||||
contiguity.push_back(1);
|
||||
constancy.push_back(
|
||||
std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
|
||||
divisibility.push_back(highestPowOf2Divisor(constantValue.value()));
|
||||
} else {
|
||||
contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
|
||||
constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
|
||||
divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
|
||||
}
|
||||
}
|
||||
return AxisInfo(contiguity, divisibility, constancy, constantValue);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
virtual std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class AxisInfoVisitorList {
|
||||
public:
|
||||
template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
||||
void append() {
|
||||
(visitors.emplace_back(std::make_unique<Ts>()), ...);
|
||||
}
|
||||
|
||||
AxisInfo apply(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
|
||||
for (auto &visitor : visitors)
|
||||
if (visitor->match(op))
|
||||
return visitor->getAxisInfo(op, operands);
|
||||
return AxisInfo();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
|
||||
};
|
||||
|
||||
class AxisInfoAnalysis
|
||||
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
|
||||
private:
|
||||
AxisInfoVisitorList visitors;
|
||||
|
||||
void setToEntryState(dataflow::Lattice<AxisInfo> *lattice) override {
|
||||
propagateIfChanged(
|
||||
lattice,
|
||||
lattice->join(AxisInfo::getPessimisticValueState(lattice->getPoint())));
|
||||
}
|
||||
|
||||
public:
|
||||
AxisInfoAnalysis(DataFlowSolver &solver);
|
||||
using dataflow::SparseDataFlowAnalysis<
|
||||
dataflow::Lattice<AxisInfo>>::getLatticeElement;
|
||||
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;
|
||||
|
||||
void visitOperation(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
|
||||
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
|
||||
};
|
||||
|
||||
/// Module level axis info analysis based on the call graph, assuming that we
|
||||
/// do not have recursive functions.
|
||||
/// Since each function will be called multiple times, we need to
|
||||
/// calculate the axis info based on the axis info of all the callers.
|
||||
/// In the future, we can perform optimization using function cloning so that
|
||||
/// each call site will have unique axis info.
|
||||
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
|
||||
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
|
||||
public:
|
||||
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
|
||||
: CallGraph<AxisInfoMapT>(moduleOp) {
|
||||
SmallVector<FunctionOpInterface> funcs;
|
||||
for (auto root : getRoots()) {
|
||||
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
|
||||
// Pre-order edge walk callback
|
||||
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
|
||||
// Post-order node walk callback
|
||||
[&](FunctionOpInterface funcOp) {
|
||||
funcs.push_back(funcOp);
|
||||
funcMap.try_emplace(funcOp, AxisInfoMapT{});
|
||||
});
|
||||
}
|
||||
SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
|
||||
SymbolTableCollection symbolTable;
|
||||
for (auto funcOp : llvm::reverse(sortedFuncs)) {
|
||||
initialize(funcOp);
|
||||
funcOp.walk([&](CallOpInterface callOp) {
|
||||
auto callee =
|
||||
dyn_cast<FunctionOpInterface>(callOp.resolveCallable(&symbolTable));
|
||||
update(callOp, callee);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
AxisInfo *getAxisInfo(Value value) {
|
||||
auto funcOp =
|
||||
value.getParentRegion()->getParentOfType<FunctionOpInterface>();
|
||||
auto *axisInfoMap = getFuncData(funcOp);
|
||||
if (!axisInfoMap) {
|
||||
return nullptr;
|
||||
}
|
||||
auto it = axisInfoMap->find(value);
|
||||
if (it == axisInfoMap->end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return &(it->second);
|
||||
}
|
||||
|
||||
unsigned getPtrContiguity(Value ptr);
|
||||
|
||||
unsigned getPtrAlignment(Value ptr);
|
||||
|
||||
unsigned getMaskAlignment(Value mask);
|
||||
|
||||
private:
|
||||
void initialize(FunctionOpInterface funcOp);
|
||||
|
||||
void update(CallOpInterface callOp, FunctionOpInterface funcOp);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
152
python/triton/third_party/hip/include/triton/AnalysisROCM/Membar.h
vendored
Normal file
152
python/triton/third_party/hip/include/triton/AnalysisROCM/Membar.h
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
#ifndef TRITON_ANALYSISROCM_MEMBAR_H
|
||||
#define TRITON_ANALYSISROCM_MEMBAR_H
|
||||
|
||||
#include "Allocation.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class OpBuilder;
|
||||
|
||||
struct BlockInfo {
|
||||
using BufferIdSetT = Allocation::BufferIdSetT;
|
||||
using IntervalSetT = std::set<Interval<size_t>>;
|
||||
|
||||
IntervalSetT syncReadIntervals;
|
||||
IntervalSetT syncWriteIntervals;
|
||||
|
||||
BlockInfo() = default;
|
||||
|
||||
/// Unions two BlockInfo objects.
|
||||
BlockInfo &join(const BlockInfo &other) {
|
||||
syncReadIntervals.insert(other.syncReadIntervals.begin(),
|
||||
other.syncReadIntervals.end());
|
||||
syncWriteIntervals.insert(other.syncWriteIntervals.begin(),
|
||||
other.syncWriteIntervals.end());
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns true if intervals in two BlockInfo objects are intersected.
|
||||
bool isIntersected(const BlockInfo &other) const {
|
||||
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) ||
|
||||
/*WAR*/
|
||||
isIntersected(syncReadIntervals, other.syncWriteIntervals) ||
|
||||
/*WAW*/
|
||||
isIntersected(syncWriteIntervals, other.syncWriteIntervals);
|
||||
}
|
||||
|
||||
/// Clears the intervals because a barrier is inserted.
|
||||
void sync() {
|
||||
syncReadIntervals.clear();
|
||||
syncWriteIntervals.clear();
|
||||
}
|
||||
|
||||
/// Compares two BlockInfo objects.
|
||||
bool operator==(const BlockInfo &other) const {
|
||||
return syncReadIntervals == other.syncReadIntervals &&
|
||||
syncWriteIntervals == other.syncWriteIntervals;
|
||||
}
|
||||
|
||||
bool operator!=(const BlockInfo &other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
bool isIntersected(const IntervalSetT &lhsIntervalSet,
|
||||
const IntervalSetT &rhsIntervalSet) const {
|
||||
for (auto &lhs : lhsIntervalSet)
|
||||
for (auto &rhs : rhsIntervalSet)
|
||||
if (lhs.intersects(rhs))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Barrier Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class MembarAnalysis {
|
||||
public:
|
||||
using FuncBlockInfoMapT = CallGraph<BlockInfo>::FuncDataMapT;
|
||||
/// Creates a new Membar analysis that generates the shared memory barrier
|
||||
/// in the following circumstances:
|
||||
/// - RAW: If a shared memory write is followed by a shared memory read, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// - WAR: If a shared memory read is followed by a shared memory write, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// The following circumstances do not require a barrier:
|
||||
/// - WAW: not possible because overlapped memory allocation is not allowed.
|
||||
/// - RAR: no write is performed.
|
||||
/// Temporary storage of operations such as Reduce are considered as both
|
||||
/// a shared memory read. If the temporary storage is written but not read,
|
||||
/// it is considered as the problem of the operation itself but not the membar
|
||||
/// analysis.
|
||||
MembarAnalysis() = default;
|
||||
explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
|
||||
|
||||
/// Runs the membar analysis to the given operation, inserts a barrier if
|
||||
/// necessary.
|
||||
void run(FuncBlockInfoMapT &funcBlockInfoMap);
|
||||
|
||||
private:
|
||||
/// Applies the barrier analysis based on the SCF dialect, in which each
|
||||
/// region has a single basic block only.
|
||||
/// Example:
|
||||
/// region1
|
||||
/// op1
|
||||
/// op2 (scf.if)
|
||||
/// region2
|
||||
/// op3
|
||||
/// op4
|
||||
/// region3
|
||||
/// op5
|
||||
/// op6
|
||||
/// op7
|
||||
/// TODO: Explain why we don't use ForwardAnalysis:
|
||||
void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap,
|
||||
OpBuilder *builder);
|
||||
|
||||
/// Updates the BlockInfo operation based on the operation.
|
||||
void update(Operation *operation, BlockInfo *blockInfo,
|
||||
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder);
|
||||
|
||||
/// Collects the successors of the terminator
|
||||
void visitTerminator(Operation *operation, SmallVector<Block *> &successors);
|
||||
|
||||
private:
|
||||
Allocation *allocation = nullptr;
|
||||
};
|
||||
|
||||
/// Postorder traversal on the callgraph to insert membar instructions
|
||||
/// of each function.
|
||||
/// Each function maintains a BlockInfo map that includes all potential buffers
|
||||
/// after returning. This way users do not have to explicitly insert membars
|
||||
/// before and after function calls, but might be a bit conservative.
|
||||
class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
|
||||
public:
|
||||
ModuleMembarAnalysis(ModuleAllocation *moduleAllocation)
|
||||
: CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
|
||||
moduleAllocation(moduleAllocation) {}
|
||||
|
||||
void run() {
|
||||
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
|
||||
// Pre-order walk callback
|
||||
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
|
||||
// Post-order walk callback
|
||||
[&](FunctionOpInterface funcOp) {
|
||||
auto *allocation = moduleAllocation->getFuncData(funcOp);
|
||||
auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
|
||||
if (inserted) {
|
||||
MembarAnalysis analysis(allocation);
|
||||
analysis.run(funcMap);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
ModuleAllocation *moduleAllocation;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSISROCM_MEMBAR_H
|
||||
347
python/triton/third_party/hip/include/triton/AnalysisROCM/Utility.h
vendored
Normal file
347
python/triton/third_party/hip/include/triton/AnalysisROCM/Utility.h
vendored
Normal file
@@ -0,0 +1,347 @@
|
||||
#ifndef TRITON_ANALYSISROCM_UTILITY_H
|
||||
#define TRITON_ANALYSISROCM_UTILITY_H
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ReduceOpHelper {
|
||||
public:
|
||||
explicit ReduceOpHelper(triton::ReduceOp op)
|
||||
: op(op.getOperation()), axis(op.getAxis()) {
|
||||
auto firstTy = op.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
srcShape = firstTy.getShape();
|
||||
srcEncoding = firstTy.getEncoding();
|
||||
srcElementTypes = op.getElementTypes();
|
||||
|
||||
for (const auto &t : op.getInputTypes()) {
|
||||
if (t.getShape() != srcShape) {
|
||||
op.emitError() << "shape mismatch";
|
||||
}
|
||||
if (t.getEncoding() != srcEncoding) {
|
||||
op.emitError() << "encoding mismatch";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getSrcShape() { return srcShape; }
|
||||
|
||||
Attribute getSrcLayout() { return srcEncoding; }
|
||||
|
||||
triton::ReduceOp getOperation() { return op; }
|
||||
|
||||
bool isFastReduction();
|
||||
|
||||
bool isWarpSynchronous();
|
||||
|
||||
unsigned getInterWarpSize();
|
||||
|
||||
unsigned getIntraWarpSize();
|
||||
|
||||
unsigned getInterWarpSizeWithUniqueData();
|
||||
|
||||
unsigned getIntraWarpSizeWithUniqueData();
|
||||
|
||||
unsigned getThreadsReductionAxis();
|
||||
|
||||
SmallVector<unsigned> getScratchConfigBasic();
|
||||
|
||||
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
|
||||
|
||||
unsigned getScratchSizeInBytes();
|
||||
|
||||
bool isSupportedLayout();
|
||||
|
||||
private:
|
||||
triton::ReduceOp op;
|
||||
ArrayRef<int64_t> srcShape;
|
||||
Attribute srcEncoding;
|
||||
SmallVector<Type> srcElementTypes;
|
||||
int axis;
|
||||
};
|
||||
|
||||
class ScanLoweringHelper {
|
||||
public:
|
||||
explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
|
||||
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
srcEncoding = type.getEncoding();
|
||||
}
|
||||
// Return true if the lowering of the scan op is supported.
|
||||
bool isSupported();
|
||||
// Return the number of elements per thread along axis dim.
|
||||
unsigned getAxisNumElementsPerThread();
|
||||
// Return the number of elements per thread along non-axis dims.
|
||||
unsigned getNonAxisNumElementsPerThread();
|
||||
// Return the number of threads per warp along non-axis dims.
|
||||
unsigned getNonAxisNumThreadsPerWarp();
|
||||
// Return the flat numbers of threads computing independent scan results.
|
||||
unsigned getNonAxisNumThreadsPerCTA();
|
||||
// Return the number of warps per CTA along axis dim.
|
||||
unsigned getAxisNumWarps();
|
||||
// Return the number of threads per warp along axis dim.
|
||||
unsigned getAxisNumThreadsPerWarp();
|
||||
// Return the number of blocks along axis dim.
|
||||
unsigned getAxisNumBlocks();
|
||||
// Return the number of blocks along non axis dim.
|
||||
unsigned getNonAxisNumBlocks();
|
||||
// Return the size of the scratch space needed for scan lowering.
|
||||
unsigned getScratchSizeInBytes();
|
||||
|
||||
// Stride between contiguous element along axis dim.
|
||||
unsigned getAxisElementStride();
|
||||
// Stride between contiguous threads along axis dim.
|
||||
unsigned getAxisThreadStride();
|
||||
// Stride between contiguous blocks along axis dim.
|
||||
unsigned getAxisBlockStride();
|
||||
|
||||
Location getLoc() { return scanOp.getLoc(); }
|
||||
unsigned getAxis() { return scanOp.getAxis(); }
|
||||
triton::gpu_rocm::BlockedEncodingAttr getEncoding();
|
||||
Region &getCombineOp();
|
||||
|
||||
private:
|
||||
triton::ScanOp scanOp;
|
||||
Attribute srcEncoding;
|
||||
};
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op);
|
||||
|
||||
bool maybeAliasOp(Operation *op);
|
||||
|
||||
#if 1
|
||||
bool supportMFMA(triton::DotOp op, int64_t nonKDim);
|
||||
#endif
|
||||
|
||||
bool supportMMA(triton::DotOp op, int version);
|
||||
|
||||
bool supportMMA(Value value, int version);
|
||||
|
||||
bool isSingleValue(Value value);
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
// TODO: Move utility functions that belong to ConvertLayoutOp to class
|
||||
// ConvertLayoutOpHelper in the future
|
||||
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
|
||||
SmallVector<T_OUT> out;
|
||||
for (const T_IN &i : in)
|
||||
out.push_back(T_OUT(i));
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
|
||||
}
|
||||
|
||||
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
|
||||
|
||||
/// output[i] = input[order[i]]
|
||||
template <typename T, typename RES_T = T>
|
||||
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
size_t rank = order.size();
|
||||
assert(input.size() == rank);
|
||||
SmallVector<RES_T> result(rank);
|
||||
for (auto it : llvm::enumerate(order)) {
|
||||
result[it.index()] = input[it.value()];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Get the highest power of 2 divisor of an integer.
|
||||
template <typename T> T highestPowOf2Divisor(T n) {
|
||||
if (n == 0) {
|
||||
return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
|
||||
}
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
/// Get the next power of 2 for an integer (or the integer itself if it is a
|
||||
/// power of 2).
|
||||
template <typename T> T nextPowOf2(T n) {
|
||||
if (n == 0) {
|
||||
return 1;
|
||||
}
|
||||
n--;
|
||||
for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) {
|
||||
n |= n >> i;
|
||||
}
|
||||
return n + 1;
|
||||
}
|
||||
|
||||
#if 1
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
#endif
|
||||
|
||||
/// Multi-root DAG topological sort.
|
||||
/// Performs a topological sort of the Operation in the `toSort` SetVector.
|
||||
/// Returns a topologically sorted SetVector.
|
||||
/// It is faster than mlir::topologicalSort because it prunes nodes that have
|
||||
/// been visited before.
|
||||
SetVector<Operation *>
|
||||
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
|
||||
|
||||
/// This uses the toplogicalSort above
|
||||
SetVector<Operation *>
|
||||
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
|
||||
TransitiveFilter forwardFilter = nullptr);
|
||||
|
||||
/// Create a basic DataFlowSolver with constant and dead code analysis included.
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
|
||||
|
||||
/// This class represents a call graph for a given ModuleOp and holds
|
||||
/// data of type T associated with each FunctionOpInterface.
|
||||
template <typename T> class CallGraph {
|
||||
public:
|
||||
using FuncDataMapT = DenseMap<FunctionOpInterface, T>;
|
||||
|
||||
/// Constructor that builds the call graph for the given moduleOp.
|
||||
explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); }
|
||||
|
||||
/// Walks the call graph and applies the provided update functions
|
||||
/// to the edges and nodes.
|
||||
template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
|
||||
WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
|
||||
typename UpdateEdgeFn, typename UpdateNodeFn>
|
||||
void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) {
|
||||
DenseSet<FunctionOpInterface> visited;
|
||||
for (auto root : roots) {
|
||||
doWalk<UpdateEdgeOrder, UpdateNodeOrder>(root, visited, updateEdgeFn,
|
||||
updateNodeFn);
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves the data associated with a function
|
||||
T *getFuncData(FunctionOpInterface funcOp) {
|
||||
if (funcMap.count(funcOp)) {
|
||||
return &funcMap[funcOp];
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Getters
|
||||
ModuleOp getModuleOp() const { return moduleOp; }
|
||||
SmallVector<FunctionOpInterface> getRoots() const { return roots; }
|
||||
size_t getNumFunctions() const { return funcMap.size(); }
|
||||
|
||||
/// Returns true if the given function is a root.
|
||||
bool isRoot(FunctionOpInterface funcOp) const {
|
||||
return llvm::is_contained(roots, funcOp);
|
||||
}
|
||||
|
||||
/// Maps the data and the graph nodes associated with a funcOp to a
|
||||
/// targetFuncOp.
|
||||
template <typename FROM, typename TO>
|
||||
void mapFuncOp(FROM funcOp, TO targetFuncOp) {
|
||||
// Iterate over graph and replace
|
||||
for (auto &kv : graph) {
|
||||
for (auto &edge : kv.second) {
|
||||
if (edge.second == funcOp) {
|
||||
edge.second = targetFuncOp;
|
||||
}
|
||||
}
|
||||
}
|
||||
graph[targetFuncOp] = graph[funcOp];
|
||||
// Replace in roots
|
||||
for (auto it = roots.begin(); it != roots.end(); ++it) {
|
||||
if (*it == funcOp) {
|
||||
*it = targetFuncOp;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Replace in funcMap
|
||||
funcMap[targetFuncOp] = funcMap[funcOp];
|
||||
}
|
||||
|
||||
/// Maps the graph edges associated with a callOp to a targetCallOp.
|
||||
template <typename FROM, typename TO>
|
||||
void mapCallOp(FROM callOp, TO targetCallOp) {
|
||||
// Iterate over graph and replace
|
||||
for (auto &kv : graph) {
|
||||
for (auto &edge : kv.second) {
|
||||
if (edge.first == callOp) {
|
||||
edge.first = targetCallOp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void build() {
|
||||
SymbolTableCollection symbolTable;
|
||||
DenseSet<FunctionOpInterface> visited;
|
||||
// Build graph
|
||||
moduleOp.walk([&](Operation *op) {
|
||||
auto caller = op->getParentOfType<FunctionOpInterface>();
|
||||
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
|
||||
auto *callee = callOp.resolveCallable(&symbolTable);
|
||||
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callee);
|
||||
if (funcOp) {
|
||||
graph[caller].emplace_back(
|
||||
std::pair<CallOpInterface, FunctionOpInterface>(callOp, funcOp));
|
||||
visited.insert(funcOp);
|
||||
}
|
||||
}
|
||||
});
|
||||
// Find roots
|
||||
moduleOp.walk([&](FunctionOpInterface funcOp) {
|
||||
if (!visited.count(funcOp)) {
|
||||
roots.push_back(funcOp);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
|
||||
WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
|
||||
typename UpdateEdgeFn, typename UpdateNodeFn>
|
||||
void doWalk(FunctionOpInterface funcOp,
|
||||
DenseSet<FunctionOpInterface> &visited, UpdateEdgeFn updateEdgeFn,
|
||||
UpdateNodeFn updateNodeFn) {
|
||||
if (visited.count(funcOp)) {
|
||||
llvm::report_fatal_error("Cycle detected in call graph");
|
||||
}
|
||||
if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) {
|
||||
updateNodeFn(funcOp);
|
||||
}
|
||||
for (auto [callOp, callee] : graph[funcOp]) {
|
||||
if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) {
|
||||
updateEdgeFn(callOp, callee);
|
||||
}
|
||||
doWalk<UpdateEdgeOrder, UpdateNodeOrder>(callee, visited, updateEdgeFn,
|
||||
updateNodeFn);
|
||||
if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) {
|
||||
updateEdgeFn(callOp, callee);
|
||||
}
|
||||
}
|
||||
if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) {
|
||||
updateNodeFn(funcOp);
|
||||
}
|
||||
visited.erase(funcOp);
|
||||
}
|
||||
|
||||
protected:
|
||||
ModuleOp moduleOp;
|
||||
DenseMap<FunctionOpInterface,
|
||||
SmallVector<std::pair<CallOpInterface, FunctionOpInterface>>>
|
||||
graph;
|
||||
FuncDataMapT funcMap;
|
||||
SmallVector<FunctionOpInterface> roots;
|
||||
};
|
||||
// Create a basic DataFlowSolver with constant and dead code analysis included.
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
|
||||
|
||||
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSISROCM_UTILITY_H
|
||||
3
python/triton/third_party/hip/include/triton/CMakeLists.txt
vendored
Normal file
3
python/triton/third_party/hip/include/triton/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
# add_subdirectory(Target)
|
||||
3
python/triton/third_party/hip/include/triton/Conversion/CMakeLists.txt
vendored
Normal file
3
python/triton/third_party/hip/include/triton/Conversion/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
add_subdirectory(TritonToTritonGPUROCM)
|
||||
add_subdirectory(TritonGPUROCMToLLVM)
|
||||
# add_subdirectory(NVGPUToLLVM)
|
||||
27
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/AsmFormat.h
vendored
Normal file
27
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/AsmFormat.h
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
#ifndef TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
#define TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
class ConversionPatternRewriter;
|
||||
class Location;
|
||||
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
inline std::string strJoin(llvm::ArrayRef<std::string> strs,
|
||||
llvm::StringRef delimiter) {
|
||||
return llvm::join(strs.begin(), strs.end(), delimiter);
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
3
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/CMakeLists.txt
vendored
Normal file
3
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUROCMToLLVM)
|
||||
add_public_tablegen_target(TritonGPUROCMConversionPassIncGen)
|
||||
381
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/GCNAsmFormat.h
vendored
Normal file
381
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/GCNAsmFormat.h
vendored
Normal file
@@ -0,0 +1,381 @@
|
||||
#ifndef TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_GCN_FORMAT_H_
|
||||
#define TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_GCN_FORMAT_H_
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
class ConversionPatternRewriter;
|
||||
class Location;
|
||||
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
class GCNInstr;
|
||||
class GCNInstrCommon;
|
||||
class GCNInstrExecution;
|
||||
|
||||
// GCNBuilder helps to manage a GCN asm program consists of one or multiple
|
||||
// instructions.
|
||||
//
|
||||
// A helper for building an ASM program, the objective of GCNBuilder is to give
|
||||
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
|
||||
// Currently, several factors are introduced to reduce the need for mixing
|
||||
// string and C++ if-else code.
|
||||
//
|
||||
// Usage:
|
||||
// To create a multiplcation operation
|
||||
//
|
||||
//
|
||||
// GCNBuilder gcnBuilder;
|
||||
// unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
//
|
||||
// const std::string readConstraint = "v";
|
||||
// const std::string writeConstraint = "=v";
|
||||
// auto res = gcnBuilder.newOperand(writeConstraint);
|
||||
// auto lhs = gcnBuilder.newOperand(operands[0], readConstraint);
|
||||
// auto rhs = gcnBuilder.newOperand(operands[1], readConstraint);
|
||||
//
|
||||
// create inst
|
||||
// auto &mul_inst =
|
||||
// gcnBuilder.create<GCNInstr>("v_mul")->float_op_type(bitwidth);
|
||||
//
|
||||
// launch insts
|
||||
// mul_inst(res, lhs, rhs);
|
||||
//
|
||||
// return result
|
||||
// Value ret = gcnBuilder.launch(rewriter, loc, elemTy, false);
|
||||
// return ret;
|
||||
// To get the asm code:
|
||||
// builder.dump()
|
||||
//
|
||||
// To get all the mlir::Value used in the GCN code,
|
||||
//
|
||||
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||
//
|
||||
// To get the string containing all the constraints with "," separated,
|
||||
// builder.getConstraints() // get "=v,v,v"
|
||||
//
|
||||
// GCNBuilder can build a GCN asm with multiple instructions, sample code:
|
||||
//
|
||||
// GCNBuilder builder;
|
||||
// auto &rcp = gcnBuilder.create<GCNInstr>("v_rcp")->float_op_type(bitwidth);
|
||||
// auto &mul_inst =
|
||||
// gcnBuilder.create<GCNInstr>("v_mul")->float_op_type(bitwidth);
|
||||
//
|
||||
// rcp(...);
|
||||
// mul_inst(...);
|
||||
// This will get a GCN code with two instructions.
|
||||
//
|
||||
// Similar to a C function, a declared GCNInstr instance can be launched
|
||||
// multiple times with different operands, e.g.
|
||||
//
|
||||
// auto &mul_inst =
|
||||
// gcnBuilder.create<GCNInstr>("v_mul")->float_op_type(bitwidth); mul_inst(...
|
||||
// some operands ...); mul_inst(... some different operands ...);
|
||||
//
|
||||
// Finally, we will get a GCN code with two mov instructions.
|
||||
//
|
||||
// There are several derived instruction type for typical instructions, for
|
||||
// example, the GCNIOInstr for ld and st instructions.
|
||||
struct GCNBuilder {
|
||||
struct Operand {
|
||||
std::string constraint;
|
||||
Value value;
|
||||
int idx{-1};
|
||||
llvm::SmallVector<Operand *> list;
|
||||
std::function<std::string(int idx)> repr;
|
||||
|
||||
// for list
|
||||
Operand() = default;
|
||||
Operand(const Operation &) = delete;
|
||||
Operand(Value value, StringRef constraint)
|
||||
: value(value), constraint(constraint) {}
|
||||
|
||||
bool isList() const { return !value && constraint.empty(); }
|
||||
|
||||
Operand *listAppend(Operand *arg) {
|
||||
list.push_back(arg);
|
||||
return this;
|
||||
}
|
||||
|
||||
Operand *listGet(size_t nth) const {
|
||||
assert(nth < list.size());
|
||||
return list[nth];
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
};
|
||||
|
||||
struct Modifier {
|
||||
Value value;
|
||||
std::string modifier;
|
||||
std::string arg;
|
||||
llvm::SmallVector<Modifier *> list;
|
||||
|
||||
Modifier() = default;
|
||||
Modifier(const Operation &) = delete;
|
||||
Modifier(Value value, StringRef arg) : value(value), arg(arg) {}
|
||||
|
||||
bool isList() const { return !value && modifier.empty(); }
|
||||
|
||||
Modifier *listAppend(Modifier *arg) {
|
||||
list.push_back(arg);
|
||||
return this;
|
||||
}
|
||||
|
||||
Modifier *listGet(size_t index) const {
|
||||
assert(index < list.size());
|
||||
return list[index];
|
||||
}
|
||||
|
||||
std::string to_str() const {
|
||||
std::string str = modifier;
|
||||
if (!arg.empty()) {
|
||||
str += ":" + arg;
|
||||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
};
|
||||
|
||||
template <typename INSTR = GCNInstr, typename... Args>
|
||||
INSTR *create(Args &&...args) {
|
||||
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
|
||||
return static_cast<INSTR *>(instrs.back().get());
|
||||
}
|
||||
|
||||
// Create a list of operands.
|
||||
Operand *newListOperand() { return newOperand(); }
|
||||
|
||||
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
|
||||
auto *list = newOperand();
|
||||
for (auto &item : items) {
|
||||
list->listAppend(newOperand(item.first, item.second));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, mlir::Value val,
|
||||
const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (int i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(val, constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (int i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
// Create a new operand. It will not add to operand list.
|
||||
// @value: the MLIR value bind to this operand.
|
||||
// @constraint: ASM operand constraint, .e.g. "=r"
|
||||
// @formatter: extra format to represent this operand in ASM code, default is
|
||||
// "%{0}".format(operand.idx).
|
||||
Operand *newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int idx)> formatter = nullptr);
|
||||
|
||||
// Create a new operand which is written to, that is, the constraint starts
|
||||
// with "=", e.g. "=r".
|
||||
Operand *newOperand(StringRef constraint);
|
||||
|
||||
// Create a constant integer operand.
|
||||
Operand *newConstantOperand(int v);
|
||||
// Create a constant operand with explicit code specified.
|
||||
Operand *newConstantOperand(const std::string &v);
|
||||
|
||||
Operand *newAddrOperand(mlir::Value addr, StringRef constraint);
|
||||
|
||||
Modifier *newModifier(StringRef modifier, StringRef arg);
|
||||
|
||||
llvm::SmallVector<Operand *, 4> getAllArgs() const;
|
||||
|
||||
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||
|
||||
std::string getConstraints() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type resTy, bool hasSideEffect = true,
|
||||
bool isAlignStack = false,
|
||||
ArrayRef<Attribute> attrs = {}) const;
|
||||
|
||||
private:
|
||||
Operand *newOperand() {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
Modifier *newModifier() {
|
||||
modArchive.emplace_back(std::make_unique<Modifier>());
|
||||
return modArchive.back().get();
|
||||
}
|
||||
|
||||
friend class GCNInstr;
|
||||
friend class GCNInstrCommon;
|
||||
|
||||
protected:
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
llvm::SmallVector<std::unique_ptr<Modifier>, 2> modArchive;
|
||||
llvm::SmallVector<std::unique_ptr<GCNInstrCommon>, 2> instrs;
|
||||
llvm::SmallVector<std::unique_ptr<GCNInstrExecution>, 4> executions;
|
||||
int oprCounter{};
|
||||
};
|
||||
|
||||
// GCN instruction common interface.
|
||||
// Put the generic logic for all the instructions here.
|
||||
struct GCNInstrCommon {
|
||||
explicit GCNInstrCommon(GCNBuilder *builder) : builder(builder) {}
|
||||
|
||||
using Operand = GCNBuilder::Operand;
|
||||
using Modifier = GCNBuilder::Modifier;
|
||||
|
||||
// clang-format off
|
||||
GCNInstrExecution& operator()() { return call({}, {}); }
|
||||
GCNInstrExecution& operator()(Operand* a) { return call({a}, {}); }
|
||||
GCNInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}, {}); }
|
||||
GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}, {}); }
|
||||
GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}, {}); }
|
||||
GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}, {}); }
|
||||
GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}, {}); }
|
||||
GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}, {}); }
|
||||
// clang-format on
|
||||
|
||||
// Set operands of this instruction.
|
||||
GCNInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
|
||||
llvm::ArrayRef<Modifier *> mods);
|
||||
|
||||
protected:
|
||||
GCNInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
|
||||
ArrayRef<Modifier *> mods);
|
||||
|
||||
GCNBuilder *builder{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
|
||||
friend class GCNInstrExecution;
|
||||
};
|
||||
|
||||
template <class ConcreteT> struct GCNInstrBase : public GCNInstrCommon {
|
||||
using Operand = GCNBuilder::Operand;
|
||||
using Modifier = GCNBuilder::Modifier;
|
||||
|
||||
explicit GCNInstrBase(GCNBuilder *builder, const std::string &name)
|
||||
: GCNInstrCommon(builder) {
|
||||
o(name);
|
||||
}
|
||||
|
||||
ConcreteT &o(const std::string &suffix, bool predicate = true) {
|
||||
if (predicate)
|
||||
instrParts.push_back(suffix);
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
};
|
||||
|
||||
enum VectorWidth { Byte = 8, Short = 16, Dword = 32, Qword = 64 };
|
||||
|
||||
struct GCNInstr : public GCNInstrBase<GCNInstr> {
|
||||
using GCNInstrBase<GCNInstr>::GCNInstrBase;
|
||||
|
||||
GCNInstr &float_op_type(int width) {
|
||||
switch (width) {
|
||||
case Byte:
|
||||
assert(Byte != width);
|
||||
break;
|
||||
case Short:
|
||||
o("f16");
|
||||
break;
|
||||
case Dword:
|
||||
o("f32");
|
||||
break;
|
||||
case Qword:
|
||||
o("f64");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
struct GCNInstrExecution {
|
||||
using Operand = GCNBuilder::Operand;
|
||||
using Modifier = GCNBuilder::Modifier;
|
||||
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
llvm::SmallVector<Modifier *> mods;
|
||||
|
||||
GCNInstrExecution() = default;
|
||||
explicit GCNInstrExecution(GCNInstrCommon *instr,
|
||||
llvm::ArrayRef<Operand *> oprs,
|
||||
llvm::ArrayRef<Modifier *> modifiers)
|
||||
: instr(instr), argsInOrder(oprs.begin(), oprs.end()),
|
||||
mods(modifiers.begin(), modifiers.end()) {}
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
SmallVector<Operand *> getArgList() const;
|
||||
|
||||
GCNInstrCommon *instr{};
|
||||
};
|
||||
|
||||
struct GCNMemInstr : public GCNInstrBase<GCNMemInstr> {
|
||||
using GCNInstrBase<GCNMemInstr>::GCNInstrBase;
|
||||
// Add specific type suffix to instruction
|
||||
|
||||
GCNMemInstr &load_type(int width) {
|
||||
switch (width) {
|
||||
case Byte:
|
||||
o("ubyte");
|
||||
break;
|
||||
case Short:
|
||||
o("ushort");
|
||||
break;
|
||||
case Dword:
|
||||
o("dword");
|
||||
break;
|
||||
case Qword:
|
||||
o("dwordx2");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
GCNMemInstr &store_type(int width) {
|
||||
switch (width) {
|
||||
case Byte:
|
||||
o("byte");
|
||||
break;
|
||||
case Short:
|
||||
o("short");
|
||||
break;
|
||||
case Dword:
|
||||
o("dword");
|
||||
break;
|
||||
case Qword:
|
||||
o("dwordx2");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
338
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/PTXAsmFormat.h
vendored
Normal file
338
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/PTXAsmFormat.h
vendored
Normal file
@@ -0,0 +1,338 @@
|
||||
#ifndef TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_PTX_ASM_FORMAT_H_
|
||||
#define TRITON_CONVERSION_ROCM_TRITON_GPU_TO_LLVM_PTX_ASM_FORMAT_H_
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
class ConversionPatternRewriter;
|
||||
class Location;
|
||||
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
struct PTXInstr;
|
||||
struct PTXInstrCommon;
|
||||
struct PTXInstrExecution;
|
||||
|
||||
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
||||
// instructions.
|
||||
//
|
||||
// A helper for building an ASM program, the objective of PTXBuilder is to give
|
||||
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
|
||||
// Currently, several factors are introduced to reduce the need for mixing
|
||||
// string and C++ if-else code.
|
||||
//
|
||||
// Usage:
|
||||
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
|
||||
// "b"(p));
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& add = builder.create<>();
|
||||
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
|
||||
// // predicate here binds %0 to pVal, pVal is a mlir::Value
|
||||
//
|
||||
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
||||
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
||||
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
||||
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
|
||||
//
|
||||
// To get the asm code:
|
||||
// builder.dump()
|
||||
//
|
||||
// To get all the mlir::Value used in the PTX code,
|
||||
//
|
||||
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||
//
|
||||
// To get the string containing all the constraints with "," separated,
|
||||
// builder.getConstraints() // get "=r,r,k"
|
||||
//
|
||||
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& mov = builder.create("mov");
|
||||
// auto& cp = builder.create("cp");
|
||||
// mov(...);
|
||||
// cp(...);
|
||||
// This will get a PTX code with two instructions.
|
||||
//
|
||||
// Similar to a C function, a declared PTXInstr instance can be launched
|
||||
// multiple times with different operands, e.g.
|
||||
//
|
||||
// auto& mov = builder.create("mov");
|
||||
// mov(... some operands ...);
|
||||
// mov(... some different operands ...);
|
||||
//
|
||||
// Finally, we will get a PTX code with two mov instructions.
|
||||
//
|
||||
// There are several derived instruction type for typical instructions, for
|
||||
// example, the PtxIOInstr for ld and st instructions.
|
||||
struct PTXBuilder {
|
||||
struct Operand {
|
||||
std::string constraint;
|
||||
Value value;
|
||||
int idx{-1};
|
||||
llvm::SmallVector<Operand *> list;
|
||||
std::function<std::string(int idx)> repr;
|
||||
|
||||
// for list
|
||||
Operand() = default;
|
||||
Operand(const Operation &) = delete;
|
||||
Operand(Value value, StringRef constraint)
|
||||
: constraint(constraint), value(value) {}
|
||||
|
||||
bool isList() const { return !value && constraint.empty(); }
|
||||
|
||||
Operand *listAppend(Operand *arg) {
|
||||
list.push_back(arg);
|
||||
return this;
|
||||
}
|
||||
|
||||
Operand *listGet(size_t nth) const {
|
||||
assert(nth < list.size());
|
||||
return list[nth];
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
};
|
||||
|
||||
template <typename INSTR = PTXInstr, typename... Args>
|
||||
INSTR *create(Args &&...args) {
|
||||
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
|
||||
return static_cast<INSTR *>(instrs.back().get());
|
||||
}
|
||||
|
||||
// Create a list of operands.
|
||||
Operand *newListOperand() { return newOperand(); }
|
||||
|
||||
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
|
||||
auto *list = newOperand();
|
||||
for (auto &item : items) {
|
||||
list->listAppend(newOperand(item.first, item.second));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, mlir::Value val,
|
||||
const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(val, constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
// Create a new operand. It will not add to operand list.
|
||||
// @value: the MLIR value bind to this operand.
|
||||
// @constraint: ASM operand constraint, .e.g. "=r"
|
||||
// @formatter: extra format to represent this operand in ASM code, default is
|
||||
// "%{0}".format(operand.idx).
|
||||
Operand *newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int idx)> formatter = nullptr);
|
||||
|
||||
// Create a new operand which is written to, that is, the constraint starts
|
||||
// with "=", e.g. "=r".
|
||||
// If the operand will be used in predicated execution,
|
||||
// users may want to initialize it before use.
|
||||
// Otherwise if the register is only used in the true branch or the false
|
||||
// branch but not both, the register is undefined and ptxas can perform
|
||||
// aggressive optimizations that may lead to incorrect results.
|
||||
Operand *newOperand(StringRef constraint, bool init = false);
|
||||
|
||||
// Create a new operand that is tied to a previous operand. In this case the
|
||||
// asm would be permitted to write to an input register. Instead of providing
|
||||
// constraint code for this operand, the constraint code of the tied operand
|
||||
// is used.
|
||||
Operand *newOperand(unsigned operandIndex);
|
||||
|
||||
// Create a constant integer operand.
|
||||
Operand *newConstantOperand(int64_t v);
|
||||
// Create a constant operand with explicit code specified.
|
||||
Operand *newConstantOperand(const std::string &v);
|
||||
|
||||
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
|
||||
|
||||
llvm::SmallVector<Operand *, 4> getAllArgs() const;
|
||||
|
||||
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||
|
||||
std::string getConstraints() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy,
|
||||
bool hasSideEffect = true, bool isAlignStack = false,
|
||||
ArrayRef<Attribute> attrs = {}) const;
|
||||
|
||||
private:
|
||||
Operand *newOperand() {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
void initOperand(Operand *opr);
|
||||
|
||||
// Make the operands in argArchive follow the provided \param order.
|
||||
void reorderArgArchive(ArrayRef<Operand *> order) {
|
||||
assert(order.size() == argArchive.size());
|
||||
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
|
||||
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
|
||||
// determined by PTX code snippet passed from external.
|
||||
sort(argArchive.begin(), argArchive.end(),
|
||||
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
|
||||
auto ida = std::find(order.begin(), order.end(), a.get());
|
||||
auto idb = std::find(order.begin(), order.end(), b.get());
|
||||
assert(ida != order.end());
|
||||
assert(idb != order.end());
|
||||
return ida < idb;
|
||||
});
|
||||
}
|
||||
|
||||
friend struct PTXInstr;
|
||||
friend struct PTXInstrCommon;
|
||||
|
||||
protected:
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
|
||||
int oprCounter{};
|
||||
};
|
||||
|
||||
// PTX instruction common interface.
|
||||
// Put the generic logic for all the instructions here.
|
||||
struct PTXInstrCommon {
|
||||
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
|
||||
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
// clang-format off
|
||||
PTXInstrExecution& operator()() { return call({}); }
|
||||
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
|
||||
// clang-format on
|
||||
|
||||
// Set operands of this instruction.
|
||||
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
protected:
|
||||
// "Call" the instruction with operands.
|
||||
// \param oprs The operands of this instruction.
|
||||
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
|
||||
// to the inline Asm without generating the operand ids(such as $0, $1) in PTX
|
||||
// code.
|
||||
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
PTXBuilder *builder{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
|
||||
friend struct PTXInstrExecution;
|
||||
};
|
||||
|
||||
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
|
||||
: PTXInstrCommon(builder) {
|
||||
o(name);
|
||||
}
|
||||
|
||||
// Append a suffix to the instruction.
|
||||
// e.g. PTXInstr("add").o("s32") get a add.s32.
|
||||
// A predicate is used to tell whether to apply the suffix, so that no if-else
|
||||
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
|
||||
// get a `add.s32` if isS32 is true.
|
||||
ConcreteT &o(const std::string &suffix, bool predicate = true) {
|
||||
if (predicate)
|
||||
instrParts.push_back(suffix);
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
};
|
||||
|
||||
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
||||
|
||||
// Append a ".global" to the instruction.
|
||||
PTXInstr &global();
|
||||
|
||||
// Append a ".shared" to the instruction.
|
||||
PTXInstr &shared();
|
||||
|
||||
// Append a ".v[0-9]+" to the instruction
|
||||
PTXInstr &v(int vecWidth, bool predicate = true);
|
||||
|
||||
// Append a".b[0-9]+" to the instruction
|
||||
PTXInstr &b(int width);
|
||||
};
|
||||
|
||||
// Record the operands and context for "launching" a PtxInstr.
|
||||
struct PTXInstrExecution {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
|
||||
PTXInstrExecution() = default;
|
||||
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||
llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs)
|
||||
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
|
||||
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Prefix a !predicate to the instruction.
|
||||
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
pred->repr = [](int idx) { return "@!$" + std::to_string(idx); };
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
SmallVector<Operand *> getArgList() const;
|
||||
|
||||
PTXInstrCommon *instr{};
|
||||
Operand *pred{};
|
||||
bool onlyAttachMLIRArgs{};
|
||||
};
|
||||
|
||||
/// ====== Some instruction wrappers ======
|
||||
// We add the wrappers to make the usage more intuitive by avoiding mixing the
|
||||
// PTX code with some trivial C++ code.
|
||||
|
||||
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
|
||||
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
||||
triton::CacheModifier modifier)
|
||||
: PTXInstrBase(builder, "cp.async") {
|
||||
o(triton::stringifyCacheModifier(modifier).str());
|
||||
o("shared");
|
||||
o("global");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
16
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/Passes.h
vendored
Normal file
16
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/Passes.h
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef CONVERSION_TRITONGPUROCM_TO_LLVM_PASSES_H
|
||||
#define CONVERSION_TRITONGPUROCM_TO_LLVM_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVMPass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/Passes.h.inc"
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
43
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/Passes.td
vendored
Normal file
43
python/triton/third_party/hip/include/triton/Conversion/TritonGPUROCMToLLVM/Passes.td
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
#ifndef CONVERSION_TRITONGPUROCM_TO_LLVM_PASSES
|
||||
#define CONVERSION_TRITONGPUROCM_TO_LLVM_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
|
||||
def ConvertTritonGPUROCMToLLVM : Pass<"convert-triton-gpurocm-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Convert TritonGPUROCM to LLVM";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonGPUROCMToLLVMPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithDialect",
|
||||
"mlir::math::MathDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::LLVM::LLVMDialect",
|
||||
"mlir::tensor::TensorDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::ROCDL::ROCDLDialect",
|
||||
"mlir::NVVM::NVVMDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">,
|
||||
Option<"tmaMetadata", "tma-metadata",
|
||||
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
|
||||
"tma metadata to the runtime">,
|
||||
Option<"target", "target", "enum Target", "mlir::triton::Target::Default",
|
||||
"compile for target compatible LLVM",
|
||||
"llvm::cl::values("
|
||||
"clEnumValN(mlir::triton::Target::NVVM, \"nvvm\", \"compile for "
|
||||
"NVVM-compatible LLVM\"), "
|
||||
"clEnumValN(mlir::triton::Target::ROCDL, \"rocdl\", \"compile for "
|
||||
"ROCDL-compatible LLVM\"))">,
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,30 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_PASS_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_PASS_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
enum Target { NVVM, ROCDL, Default = NVVM };
|
||||
|
||||
#define GEN_PASS_DECL
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/Passes.h.inc"
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUROCMToLLVMPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUROCMToLLVMPass(const ConvertTritonGPUROCMToLLVMOptions &options);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
3
python/triton/third_party/hip/include/triton/Conversion/TritonToTritonGPUROCM/CMakeLists.txt
vendored
Normal file
3
python/triton/third_party/hip/include/triton/Conversion/TritonToTritonGPUROCM/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPUROCM)
|
||||
add_public_tablegen_target(TritonToTritonGPUROCMConversionPassIncGen)
|
||||
15
python/triton/third_party/hip/include/triton/Conversion/TritonToTritonGPUROCM/Passes.h
vendored
Normal file
15
python/triton/third_party/hip/include/triton/Conversion/TritonToTritonGPUROCM/Passes.h
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
#ifndef CONVERSION_TRITON_TO_TRITONGPUROCM_PASSES_H
|
||||
#define CONVERSION_TRITON_TO_TRITONGPUROCM_PASSES_H
|
||||
|
||||
#include "triton/Conversion/TritonToTritonGPUROCM/TritonToTritonGPUPass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Conversion/TritonToTritonGPUROCM/Passes.h.inc"
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
37
python/triton/third_party/hip/include/triton/Conversion/TritonToTritonGPUROCM/Passes.td
vendored
Normal file
37
python/triton/third_party/hip/include/triton/Conversion/TritonToTritonGPUROCM/Passes.td
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
#ifndef CONVERSION_TRITON_TO_TRITONGPUROCM_PASSES
|
||||
#define CONVERSION_TRITON_TO_TRITONGPUROCM_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ConvertTritonToTritonGPUROCM: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
|
||||
let summary = "Convert Triton to TritonGPUROCM";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonToTritonGPUROCMPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithDialect",
|
||||
"mlir::math::MathDialect",
|
||||
// TODO: Does this pass depend on SCF?
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu_rocm::TritonGPUROCMDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps">,
|
||||
|
||||
Option<"threadsPerWarp", "threads-per-warp",
|
||||
"int32_t", /*default*/"64",
|
||||
"number of threads per warp">,
|
||||
Option<"numCTAs", "num-ctas",
|
||||
"int32_t", /*default*/"1",
|
||||
"number of ctas in a cga">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,31 @@
|
||||
#ifndef TRITON_CONVERSION_ROCM_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
|
||||
#define TRITON_CONVERSION_ROCM_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
constexpr static char AttrNumWarpsName[] = "triton_gpu_rocm.num-warps";
|
||||
constexpr static char AttrNumCTAsName[] = "triton_gpu_rocm.num-ctas";
|
||||
constexpr static char AttrComputeCapabilityName[] =
|
||||
"triton_gpu_rocm.compute-capability";
|
||||
|
||||
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu_rocm.threads-per-warp";
|
||||
|
||||
// Create the pass with numWarps passed from cl::opt.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUROCMPass();
|
||||
|
||||
// Create the pass with numWarps set explicitly.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUROCMPass(int numWarps, int threadsPerWarp = 64,
|
||||
int numCTAs = 1, int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
4
python/triton/third_party/hip/include/triton/Dialect/CMakeLists.txt
vendored
Normal file
4
python/triton/third_party/hip/include/triton/Dialect/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
# add_subdirectory(Triton)
|
||||
add_subdirectory(TritonGPUROCM)
|
||||
# add_subdirectory(TritonNvidiaGPU)
|
||||
# add_subdirectory(NVGPU)
|
||||
2
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/CMakeLists.txt
vendored
Normal file
2
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
7
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Attributes.h
vendored
Normal file
7
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Attributes.h
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_IR_ATTRIBUTES_H_
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_IR_ATTRIBUTES_H_
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/TritonGPUAttrDefs.h.inc"
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPUROCM_IR_ATTRIBUTES_H_
|
||||
15
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/CMakeLists.txt
vendored
Normal file
15
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu_rocm)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu_rocm)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu_rocm)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu_rocm)
|
||||
add_public_tablegen_target(TritonGPUROCMTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
|
||||
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(TritonGPUROCMAttrDefsIncGen)
|
||||
117
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Dialect.h
vendored
Normal file
117
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Dialect.h
vendored
Normal file
@@ -0,0 +1,117 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonGPU depends on Triton
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Attributes.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Traits.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace gpu_rocm {
|
||||
|
||||
unsigned getTotalElemsPerThread(Type type);
|
||||
|
||||
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
|
||||
Type eltTy);
|
||||
|
||||
SmallVector<unsigned> getElemsPerThread(Type type);
|
||||
|
||||
// Returns the number of threads per warp that may have access to replicated
|
||||
// elements. If you want non-replicated threads, use
|
||||
// getThreadsPerWarpWithUniqueData.
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
|
||||
|
||||
unsigned getWarpSize(Attribute layout);
|
||||
|
||||
// Returns the number of warps per CTA that may have access to replicated
|
||||
// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData.
|
||||
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(Attribute layout);
|
||||
|
||||
// Returns the number of contiguous elements that each thread
|
||||
// has access to, on each dimension of the tensor. E.g.
|
||||
// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
|
||||
// regardless of the shape of the tensor.
|
||||
SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
|
||||
// Returns the number of non-replicated contiguous elements that each thread
|
||||
// has access to, on each dimension of the tensor. For a blocked layout
|
||||
// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements
|
||||
// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
|
||||
// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
|
||||
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
|
||||
SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape);
|
||||
|
||||
// Returns the number of threads per warp that have access to non-replicated
|
||||
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
|
||||
// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17
|
||||
// have access to the full tensor, whereas the other threads have access to
|
||||
// replicated elements, so this function returns [2, 2].
|
||||
SmallVector<unsigned>
|
||||
getThreadsPerWarpWithUniqueData(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape);
|
||||
|
||||
// Returns the number of warps per CTA that have access to non-replicated
|
||||
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
|
||||
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
|
||||
// returns [1, 1], since the first warp has access to the full tensor, whereas
|
||||
// the other warps have access to replicated elements.
|
||||
SmallVector<unsigned>
|
||||
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getOrder(Attribute layout);
|
||||
|
||||
CTALayoutAttr getCTALayout(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTASplitNum(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTAOrder(Attribute layout);
|
||||
|
||||
/* The difference between ShapePerCTATile and ShapePerCTA:
|
||||
* (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp *
|
||||
* WarpsPerCTA in each dimension and is independent from the tensor shape.
|
||||
* (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension.
|
||||
* (3) In the implementation of emitIndices, ShapePerCTATile will
|
||||
* be replicated or wraped to fit ShapePerCTA.
|
||||
*/
|
||||
SmallVector<unsigned>
|
||||
getShapePerCTATile(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
|
||||
|
||||
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
|
||||
ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> getShapePerCTA(Type type);
|
||||
|
||||
unsigned getNumWarpsPerCTA(Attribute layout);
|
||||
|
||||
unsigned getNumCTAs(Attribute layout);
|
||||
|
||||
bool isaDistributedLayout(Attribute layout);
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPUROCM_IR_DIALECT_H_
|
||||
31
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Traits.h
vendored
Normal file
31
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Traits.h
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_IR_TRAITS_H_
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifyResultsAreSharedEncodingROCM(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <typename ConcreteType>
|
||||
class ResultsAreSharedEncodingROCM
|
||||
: public TraitBase<ConcreteType, ResultsAreSharedEncodingROCM> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyResultsAreSharedEncodingROCM(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
818
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUAttrDefs.td
vendored
Normal file
818
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUAttrDefs.td
vendored
Normal file
@@ -0,0 +1,818 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_ATTRDEFS
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_ATTRDEFS
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "triton/Dialect/TritonGPUROCM/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TritonGPU Attribute Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TritonGPUROCM_Attr<string name, list<Trait> traits = [],
|
||||
string baseCppClass = "::mlir::Attribute">
|
||||
: AttrDef<TritonGPUROCM_Dialect, name, traits, baseCppClass> {
|
||||
|
||||
let description = [{
|
||||
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
|
||||
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
|
||||
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
|
||||
to the indices of the CUDA threads allowed to access some data at index $i$.
|
||||
|
||||
For example, let us consider the layout function:
|
||||
\mathcal{L}(0, 0) = {0, 4}
|
||||
\mathcal{L}(0, 1) = {1, 5}
|
||||
\mathcal{L}(1, 0) = {2, 6}
|
||||
\mathcal{L}(1, 1) = {3, 7}
|
||||
|
||||
Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
|
||||
- T[0,0] is owned by both cuda thread 0 and 4
|
||||
- T[0,1] is owned by both cuda thread 1 and 5
|
||||
- T[1,0] is owned by both cuda thread 2 and 6
|
||||
- T[1,1] is owned by both cuda thread 3 and 7
|
||||
|
||||
Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
}];
|
||||
|
||||
code extraBaseClassDeclaration = [{
|
||||
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
|
||||
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
|
||||
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CTA Layout
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CTALayoutAttr : TritonGPUROCM_Attr<"CTALayout"> {
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$CTAsPerCGA,
|
||||
ArrayRefParameter<"unsigned">:$CTASplitNum,
|
||||
ArrayRefParameter<"unsigned">:$CTAOrder
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SharedEncodingAttr : TritonGPUROCM_Attr<"SharedEncoding"> {
|
||||
let mnemonic = "shared";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors whose elements may be simultaneously accessed by
|
||||
different cuda threads in the programs, via shared memory. In other words,
|
||||
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
|
||||
|
||||
In order to avoid shared memory bank conflicts, elements may be swizzled
|
||||
in memory. For example, a swizzled row-major layout could store its data
|
||||
as follows:
|
||||
|
||||
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
|
||||
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
|
||||
groups of vec=2 elements
|
||||
are stored contiguously
|
||||
_ _ _ _ /\_ _ _ _
|
||||
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
|
||||
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case,
|
||||
when the matrix is stored in shared memory, there will be an offset not
|
||||
only in the stride dimension, but also in the leading dimension. For example,
|
||||
a matrix of size 16x128 and data type I8 is stored in the shared memory with
|
||||
64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64,
|
||||
compared to 1*64 when the hasLeadingOffset is false.
|
||||
}];
|
||||
|
||||
// swizzle info: vec, perPhase, maxPhase
|
||||
// order: the fastest-changing axis first
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$vec,
|
||||
"unsigned":$perPhase,
|
||||
"unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned">:$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"bool":$hasLeadingOffset
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "unsigned":$vec,
|
||||
"unsigned":$perPhase,
|
||||
"unsigned":$maxPhase,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout), [{
|
||||
bool hasLeadingOffset = false; // default value
|
||||
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"unsigned":$typeWidthInBit), [{
|
||||
|
||||
#if 1
|
||||
// ---- begin GFX908/GFX90A ----
|
||||
auto mfmaEnc = dotOpEnc.getParent().dyn_cast<MfmaEncodingAttr>();
|
||||
|
||||
if (mfmaEnc) {
|
||||
int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
|
||||
bool isKDimInner = (order[0] == kDimNum);
|
||||
if (isKDimInner) {
|
||||
const int numBanks = 32;
|
||||
const int bankBitWidth = 32;
|
||||
const int SIMDWidth = 16;
|
||||
|
||||
// number of inner dimension rows per one pattern repeat
|
||||
int innerDimLength = shape[order[0]];
|
||||
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
|
||||
|
||||
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
|
||||
// Note: the following settings is customized to avoid
|
||||
// **load** bank conflicts
|
||||
//
|
||||
// vecSize is set to k_base, which is the number of elements each
|
||||
// workitem loads for one mfma instruction.
|
||||
// For now, the k_base rules are as follows
|
||||
// 1. All selected mfma instructions produce a single block
|
||||
// 2. For f16 data type, 2 VGPRs are used for operand A --> k_base = 4
|
||||
// 3. For non-f16 data types, 1 VGPR are used for operand A
|
||||
// k_base = 32 / elemTypeInBits
|
||||
// 4. TODO: what about f64?
|
||||
//
|
||||
// maxPhase is set to SIMDWidth / perPhase
|
||||
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
|
||||
int maxPhase = SIMDWidth / perPhase;
|
||||
|
||||
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
|
||||
} else {
|
||||
// Do not swizzle in case k dimension is not innermost.
|
||||
// In this case accesses will go in different banks even without swizzling.
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
// number of rows per phase
|
||||
|
||||
// index of the inner dimension in `order`
|
||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||
|
||||
// ---- begin Volta ----
|
||||
if (mmaEnc.isVolta()) {
|
||||
int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
bool is_row = order[0] != 0;
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) :
|
||||
is_row && (shapePerCTA[order[0]] <= 16);
|
||||
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
|
||||
((is_row && !is_vec4) ? 2 : 1);
|
||||
int rep = 2 * pack_size;
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
int vec = 2 * rep;
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// ---- begin Ampere ----
|
||||
if (mmaEnc.isAmpere()) {
|
||||
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
if (opIdx == 1) {
|
||||
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
|
||||
// ---- begin version 3 ----
|
||||
if (mmaEnc.isHopper()) {
|
||||
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
|
||||
" is Hopper has not been implemented yet");
|
||||
return $_get(context, 1, 1, 1, order, CTALayout, true);
|
||||
}
|
||||
|
||||
// ---- not implemented ----
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy), [{
|
||||
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy), [{
|
||||
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
|
||||
int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1;
|
||||
|
||||
// get proper shared memory swizzling mode from the contiguous dimension
|
||||
// size of the origin blocked layout.
|
||||
auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
|
||||
if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
|
||||
perPhase = 1;
|
||||
maxPhase = 8;
|
||||
} else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
|
||||
perPhase = 2;
|
||||
maxPhase = 4;
|
||||
} else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
|
||||
perPhase = 4;
|
||||
maxPhase = 2;
|
||||
} else {
|
||||
llvm_unreachable("unsupported shared memory layout for MMAv3");
|
||||
}
|
||||
|
||||
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Distributed Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class DistributedEncoding<string name> : TritonGPUROCM_Attr<name> {
|
||||
let description = [{
|
||||
Distributed encodings have a layout function that is entirely characterized
|
||||
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
|
||||
(or even the same rank) as the tensor it is encoding.
|
||||
|
||||
The layout function \mathcal{L} of this layout is then defined, for an
|
||||
index `i` \in R^D, as follows:
|
||||
|
||||
\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*A.shape[d] < L.shape[d]
|
||||
|
||||
For example, for a tensor/layout pair
|
||||
A = [x x x x x x x x]
|
||||
[x x x x x x x x]
|
||||
L = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
|
||||
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> {
|
||||
let mnemonic = "blocked";
|
||||
|
||||
let description = [{
|
||||
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
|
||||
used to promote memory coalescing in LoadInst and StoreInst.
|
||||
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
|
||||
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
|
||||
|
||||
Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows:
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
|
||||
for
|
||||
|
||||
#triton_gpu_rocm.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {1, 1}
|
||||
}>
|
||||
|
||||
Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows:
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
for
|
||||
|
||||
#triton_gpu_rocm.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {1, 1}
|
||||
}>
|
||||
|
||||
Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and
|
||||
4 CTAs (taking 2x2 for example) as follows:
|
||||
|
||||
CTA [0,0] CTA [0,1]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
|
||||
CTA [1,0] CTA [1,1]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
for
|
||||
|
||||
#triton_gpu_rocm.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {2, 2}
|
||||
}>
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$sizePerThread,
|
||||
ArrayRefParameter<"unsigned">:$threadsPerWarp,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
|
||||
"CTALayoutAttr":$CTALayout
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$numThreadsPerWarp,
|
||||
"CTALayoutAttr":$CTALayout), [{
|
||||
unsigned rank = sizePerThread.size();
|
||||
SmallVector<unsigned, 4> threadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
SmallVector<int64_t> shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
unsigned remainingLanes = numThreadsPerWarp;
|
||||
unsigned remainingThreads = numWarps * numThreadsPerWarp;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
|
||||
// starting from the contiguous dimension
|
||||
for (unsigned d = 0; d < rank - 1; ++d) {
|
||||
unsigned i = order[d];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]);
|
||||
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= threadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= threadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes;
|
||||
warpsPerCTA[order[rank - 1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$numThreadsPerWarp,
|
||||
"unsigned":$numCTAs), [{
|
||||
unsigned rank = sizePerThread.size();
|
||||
SmallVector<unsigned, 4> CTAsPerCGA(rank);
|
||||
SmallVector<unsigned, 4> CTASplitNum(rank);
|
||||
ArrayRef<unsigned> CTAOrder = order;
|
||||
|
||||
unsigned remainingCTAs = numCTAs;
|
||||
|
||||
// starting from the most strided dimension
|
||||
for (int d = rank - 1; d >= 0; --d) {
|
||||
unsigned i = order[d];
|
||||
CTAsPerCGA[i] = std::clamp<unsigned>(remainingCTAs, 1, shape[i] / sizePerThread[i]);
|
||||
CTASplitNum[i] = CTAsPerCGA[i];
|
||||
remainingCTAs /= CTAsPerCGA[i];
|
||||
}
|
||||
|
||||
CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level
|
||||
|
||||
CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
|
||||
return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MMA Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: MMAv1 and MMAv2 should be two instances of the same class
|
||||
|
||||
def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
|
||||
let mnemonic = "mma";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors that have been produced by tensor cores.
|
||||
It is characterized by two parameters:
|
||||
- A 'versionMajor' which specifies the generation the tensor cores
|
||||
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
|
||||
and 2 for second-gen tensor cores (Turing/Ampere).
|
||||
- A 'versionMinor' which indicates the specific layout of a tensor core
|
||||
generation, e.g. for Volta, there might be multiple kinds of layouts annotated
|
||||
by 0,1,2 and so on.
|
||||
- A `blockTileSize` to indicate how data should be
|
||||
partitioned between warps.
|
||||
|
||||
// -------------------------------- version = 1 --------------------------- //
|
||||
|
||||
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
|
||||
Note: the layout is different from the recommended in PTX ISA
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
|
||||
(mma.884 section, FP32 accumulator).
|
||||
|
||||
For example, when versionMinor=1, the matrix L corresponding to
|
||||
blockTileSize=[32,16] is:
|
||||
|
||||
warp 0
|
||||
--------------------------------/\-------------------------------
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
|
||||
warp 1 = warp0 + 32
|
||||
--------------------------------/\-------------------------------
|
||||
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
|
||||
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
|
||||
[ ............................................................... ]
|
||||
|
||||
|
||||
// -------------------------------- version = 2 --------------------------- //
|
||||
|
||||
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
|
||||
Information about this layout can be found in the official PTX documentation
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
|
||||
(mma.16816 section, FP32 accumulator).
|
||||
|
||||
For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
warp 0 warp 2
|
||||
-----------------/\------------- ----------------/\-------------
|
||||
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
|
||||
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
|
||||
[ .............................. ..............................
|
||||
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
|
||||
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
|
||||
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
|
||||
[ .............................. ..............................
|
||||
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
|
||||
|
||||
warp 1 warp 3
|
||||
----------------/\------------- ----------------/\-------------
|
||||
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
|
||||
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
|
||||
[ .............................. ...............................
|
||||
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
|
||||
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
|
||||
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
|
||||
[ .............................. ...............................
|
||||
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
|
||||
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$versionMajor,
|
||||
"unsigned":$versionMinor,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
ArrayRefParameter<"unsigned">:$instrShape
|
||||
);
|
||||
|
||||
let builders = [
|
||||
// Specially for MMAV1(Volta)
|
||||
AttrBuilder<(ins "int":$versionMajor,
|
||||
"int":$numWarps,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"ArrayRef<unsigned>":$instrShape,
|
||||
"ArrayRef<int64_t>":$shapeC,
|
||||
"bool":$isARow,
|
||||
"bool":$isBRow,
|
||||
"bool":$isAVec4,
|
||||
"bool":$isBVec4,
|
||||
"int":$id), [{
|
||||
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
|
||||
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
|
||||
int versionMinor = (isARow * (1<<0)) |\
|
||||
(isBRow * (1<<1)) |\
|
||||
(isAVec4 * (1<<2)) |\
|
||||
(isBVec4 * (1<<3));
|
||||
|
||||
// TODO: Share code with
|
||||
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
|
||||
// rep,spw and fpw.
|
||||
SmallVector<unsigned> wpt({1, 1});
|
||||
SmallVector<unsigned> wpt_nm1;
|
||||
|
||||
SmallVector<int, 2> rep(2), spw(2);
|
||||
std::array<int, 3> fpw{{2, 2, 1}};
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
rep[0] = 2 * packSize0;
|
||||
spw[0] = fpw[0] * 4 * rep[0];
|
||||
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
rep[1] = 2 * packSize1;
|
||||
spw[1] = fpw[1] * 4 * rep[1];
|
||||
|
||||
do {
|
||||
wpt_nm1 = wpt;
|
||||
if (wpt[0] * wpt[1] < numWarps)
|
||||
wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shapeC[0] / spw[0]);
|
||||
if (wpt[0] * wpt[1] < numWarps)
|
||||
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shapeC[1] / spw[1]);
|
||||
} while (wpt_nm1 != wpt);
|
||||
|
||||
return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "int":$versionMajor,
|
||||
"int":$numWarps,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"ArrayRef<unsigned>":$instrShape,
|
||||
"ArrayRef<int64_t>":$shapeA,
|
||||
"ArrayRef<int64_t>":$shapeB,
|
||||
"ArrayRef<int64_t>":$shapeC,
|
||||
"bool":$isARow,
|
||||
"bool":$isBRow,
|
||||
"int":$id), [{
|
||||
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
bool isVolta() const;
|
||||
bool isTuring() const;
|
||||
bool isAmpere() const;
|
||||
bool isHopper() const;
|
||||
|
||||
unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;
|
||||
|
||||
// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
|
||||
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;
|
||||
|
||||
// Number of bits in versionMinor to hold the ID of the MMA encoding instance.
|
||||
// Here 5 bits can hold 32 IDs in a single module.
|
||||
static constexpr int numBitsToHoldMmaV1ID{5};
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def MfmaEncodingAttr : DistributedEncoding<"MfmaEncoding"> {
|
||||
let mnemonic = "mfma";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors that have been produced by MI100 && MI200 tensor cores.
|
||||
It is characterized by parameters `warpsPerCTA` and `nonKDim` that indicates how data should be partitioned
|
||||
between waves (analogous to the term 'warp' used in NVIDIA's CUDA programming model).
|
||||
|
||||
Example 1:
|
||||
Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and nonKDim set to 32.
|
||||
The data will be distributed between threads as follows:
|
||||
|
||||
wave 0 wave 1
|
||||
-----------------/\-------------- -----------------/\--------------
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]
|
||||
|
||||
Example 2:
|
||||
Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and nonKDim set to 16.
|
||||
The data will be distributed between threads as follows:
|
||||
|
||||
wave 0 wave 1
|
||||
-----------------/\------------- ------------------/\---------------
|
||||
[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ]
|
||||
[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ]
|
||||
[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ]
|
||||
[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ]
|
||||
[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ]
|
||||
[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ]
|
||||
[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ]
|
||||
[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ]
|
||||
[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ]
|
||||
[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ]
|
||||
[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ]
|
||||
[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ]
|
||||
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
|
||||
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
|
||||
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
|
||||
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$nonKDim,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
"bool":$isTransposed,
|
||||
"CTALayoutAttr":$CTALayout
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
let mnemonic = "slice";
|
||||
|
||||
let description = [{
|
||||
TODO: improve docs
|
||||
|
||||
A = [x x x x x x x x]
|
||||
|
||||
parent = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
dim = 0
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
|
||||
|
||||
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
|
||||
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$dim,
|
||||
// TODO: constraint here to only take distributed encodings
|
||||
"Attribute":$parent
|
||||
);
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
template<class T>
|
||||
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
|
||||
let mnemonic = "dot_op";
|
||||
|
||||
let description = [{
|
||||
In TritonGPU dialect, considering `d = tt.dot a, b, c`
|
||||
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
|
||||
a's opIdx is 0, b's opIdx is 1.
|
||||
The parend field in DotOperandEncodingAttr is the layout of d.
|
||||
|
||||
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
|
||||
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
|
||||
section 9.7.13.4.1 for more details.
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$opIdx,
|
||||
"Attribute":$parent,
|
||||
"unsigned":$kWidth
|
||||
);
|
||||
|
||||
let builders = [
|
||||
// Specially for MMAV1(Volta)
|
||||
AttrBuilder<(ins "unsigned":$opIdx,
|
||||
"Attribute":$parent,
|
||||
"Type":$eltTy), [{
|
||||
MmaEncodingAttr parentAttr = parent.dyn_cast<MmaEncodingAttr>();
|
||||
if (!parentAttr || !parentAttr.isAmpere())
|
||||
return $_get(context, opIdx, parent, 0);
|
||||
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
|
||||
unsigned kWidth = 32 / bitwidth;
|
||||
return $_get(context, opIdx, parent, kWidth);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
bool getMMAv1IsRow() const;
|
||||
bool getMMAv1IsVec4() const;
|
||||
SmallVector<int> getMMAv1Rep() const;
|
||||
SmallVector<int> getMMAv1ShapePerWarp() const;
|
||||
int getMMAv1Vec() const;
|
||||
int getMMAv1NumOuter(ArrayRef<int64_t> shape) const;
|
||||
//
|
||||
SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
|
||||
int bitwidth) const;
|
||||
#if 1
|
||||
SmallVector<int64_t> getMFMAElemsPerInstr() const;
|
||||
SmallVector<int64_t> getMFMARep(ArrayRef<int64_t> operandShape,
|
||||
Type elemType) const;
|
||||
#endif
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
69
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUDialect.td
vendored
Normal file
69
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUDialect.td
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_DIALECT
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TritonGPUROCM_Dialect : Dialect {
|
||||
let name = "triton_gpu_rocm";
|
||||
|
||||
let cppNamespace = "::mlir::triton::gpu_rocm";
|
||||
|
||||
let hasOperationAttrVerify = 1;
|
||||
|
||||
let description = [{
|
||||
Triton GPU ROCM Dialect.
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"mlir::triton::nvgpu::NVGPUDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu_rocm.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu_rocm.num-warps"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu_rocm.num-warps attribute");
|
||||
return mod->getAttr("triton_gpu_rocm.num-warps").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getNumCTAs(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu_rocm.num-ctas"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu_rocm.num-ctas attribute");
|
||||
return mod->getAttr("triton_gpu_rocm.num-ctas").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getComputeCapability(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu_rocm.compute-capability"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu_rocm.compute-capability attribute");
|
||||
return mod->getAttrOfType<IntegerAttr>("triton_gpu_rocm.compute-capability").getInt();
|
||||
}
|
||||
void registerTypes();
|
||||
|
||||
static std::string getThreadsPerWarpAttrName() { return "triton_gpu_rocm.threads-per-warp"; }
|
||||
|
||||
static int getThreadsPerWarp(ModuleOp mod) {
|
||||
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu_rocm.threads-per-warp");
|
||||
if(!threadsPerWarp) {
|
||||
return 64;
|
||||
}
|
||||
return threadsPerWarp.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getSharedSize(ModuleOp mod) {
|
||||
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu_rocm.shared");
|
||||
if(!sharedAttr) {
|
||||
return 0;
|
||||
}
|
||||
return sharedAttr.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
}];
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let usePropertiesForAttributes = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
398
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUOps.td
vendored
Normal file
398
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUOps.td
vendored
Normal file
@@ -0,0 +1,398 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_OPS
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_OPS
|
||||
|
||||
|
||||
include "triton/Dialect/TritonGPUROCM/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/TritonGPUROCM/IR/TritonGPUTypes.td"
|
||||
include "triton/Dialect/TritonGPUROCM/IR/TritonGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
include "mlir/Interfaces/DestinationStyleOpInterface.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
|
||||
def ResultsAreSharedEncodingROCM: NativeOpTrait<"ResultsAreSharedEncodingROCM">;
|
||||
|
||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonGPUROCM_Dialect, mnemonic, traits>;
|
||||
|
||||
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
||||
[SameOperandsAndResultShape,
|
||||
SameOperandsAndResultElementType,
|
||||
Pure]> {
|
||||
let summary = "convert layout";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let hasCanonicalizeMethod = 1;
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
let summary = "async wait";
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 80;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncBulkWaitOp : TTG_Op<"async_bulk_wait"> {
|
||||
let summary = "async bulk wait";
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 90;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
|
||||
let summary = "async commit group";
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 80;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> {
|
||||
let summary = "async bulk commit group";
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 90;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
||||
// This is needed because these ops don't
|
||||
// handle encodings
|
||||
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "integer comparison operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||
TT_IntLike:$lhs,
|
||||
TT_IntLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "floating-point comparison operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
|
||||
TT_FloatLike:$lhs,
|
||||
TT_FloatLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
// TODO: migrate to arith::SelectOp on LLVM16
|
||||
def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "select operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins TT_BoolLike:$condition,
|
||||
TT_Tensor:$true_value,
|
||||
TT_Tensor:$false_value);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
// TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified
|
||||
def TTG_InsertSliceOp : TTG_Op<"insert_slice",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncodingROCM,
|
||||
MemoryEffects<[MemRead, MemWrite]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from src type",
|
||||
"src", "other", "getPointeeType($_self)",
|
||||
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
|
||||
let summary = "insert slice";
|
||||
|
||||
let description = [{
|
||||
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s
|
||||
`$index` argument and `$axis` attribute.
|
||||
|
||||
It returns a copy of `$dst` with the proper slice updated with the value of `$src`.
|
||||
|
||||
When converting from `tt.load` to `triton_gpu_rocm.insert_slice`, the `$evict`, `$cache`, and `$isVolatile` fields
|
||||
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
|
||||
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
|
||||
|
||||
The insert_slice operation supports the following arguments:
|
||||
|
||||
* src: the tensor that is inserted.
|
||||
* dst: the tensor into which the `$src` tensor is inserted.
|
||||
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
|
||||
* mask: optional tensor-rank number of boolean masks which specify which
|
||||
elements of the `$src` tensor are inserted into the `$dst` tensor.
|
||||
* other: optional tensor-rank number of other tensors which specify what
|
||||
values are inserted into the `$dst` tensor if the corresponding
|
||||
element of the `$mask` tensor is false.
|
||||
|
||||
ttgpu.load_tile_async depracate
|
||||
triton_gpu_rocm.insert_slice might be further lowered into triton_gpu_async for different hardware implementations
|
||||
|
||||
like tt.load, ttgpu.insert_slice/insert_slice_async has two modes up to the type of src
|
||||
mode 1: ptr/src is a tensor of pointers
|
||||
mode 2: ptr/src is a tensor pointer
|
||||
|
||||
Some typical lowering paths are:
|
||||
in case the load is pipelined by the pipeline pass( load is inside kBlock loop, which means "pipeline pass):
|
||||
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1)
|
||||
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1) -(MaterializeLoad)> ttgpu.insert_slice_async(mode 1) + ttgpu.await-> llvm
|
||||
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -(Pipeline)-> ttgpu.insert_slice(mode 2) -(MaterializeLoad)> ttgpu.insert_slice_async_v2(mode 2) + ttgpu.await-> llvm
|
||||
|
||||
otherwise:
|
||||
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1)
|
||||
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 1) + ttgpu.await -> llvm
|
||||
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 2) + ttgpu.await -> llvm
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%1 = triton_gpu_rocm.alloc_tensor : tensor<2x32xf32>
|
||||
%2 = triton_gpu_rocm.insert_slice %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"Value":$mask, "Value":$other,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
|
||||
DenseSet<unsigned> validLoadBytes;
|
||||
if (computeCapability >= 80) {
|
||||
validLoadBytes = {4, 8, 16};
|
||||
}
|
||||
return validLoadBytes;
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
|
||||
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncodingROCM,
|
||||
Pure,
|
||||
OffsetSizeAndStrideOpInterface
|
||||
]> {
|
||||
let summary = "extract slice operation";
|
||||
let description = [{
|
||||
same as tensor.extract_slice, but with int32 index. The motivations for re-implementing it are:
|
||||
We reimplement ExtractSliceOp with int32 index, because:
|
||||
- we want to enforce int32 indexing on GPUs since Triton tensors fit in SRAM
|
||||
- we still want to use indexWidth = 64 when lowering to LLVM because our loops can have
|
||||
64-bit induction variables and scf.for uses indexType for bounds/ivs
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyRankedTensor:$source,
|
||||
Variadic<I32>:$offsets,
|
||||
Variadic<I32>:$sizes,
|
||||
Variadic<I32>:$strides,
|
||||
DenseI64ArrayAttr:$static_offsets,
|
||||
DenseI64ArrayAttr:$static_sizes,
|
||||
DenseI64ArrayAttr:$static_strides
|
||||
);
|
||||
let results = (outs AnyRankedTensor:$result);
|
||||
|
||||
let builders = [
|
||||
// Build an ExtractSliceOp with mixed static and dynamic entries and custom
|
||||
// result type. If the type passed is nullptr, it is inferred.
|
||||
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
|
||||
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
|
||||
"ArrayRef<OpFoldResult>":$strides,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return the number of leading operands before the `offsets`, `sizes` and
|
||||
/// and `strides` operands.
|
||||
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
|
||||
|
||||
/// Returns the type of the base tensor operand.
|
||||
RankedTensorType getSourceType() {
|
||||
return getSource().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
std::array<unsigned, 3> getArrayAttrMaxRanks() {
|
||||
unsigned rank = getSourceType().getRank();
|
||||
return {rank, rank, rank};
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source ``
|
||||
custom<DynamicIndexList>($offsets, $static_offsets)
|
||||
custom<DynamicIndexList>($sizes, $static_sizes)
|
||||
custom<DynamicIndexList>($strides, $static_strides)
|
||||
attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncodingROCM,
|
||||
// TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
|
||||
MemoryEffects<[MemRead, MemWrite]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from src type",
|
||||
"src", "other", "getPointeeType($_self)",
|
||||
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
|
||||
let summary = "insert slice async";
|
||||
|
||||
let description = [{
|
||||
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s
|
||||
`$index` argument and `$axis` attribute.
|
||||
|
||||
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
|
||||
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
|
||||
|
||||
When converting from `tt.load` to `triton_gpu_rocm.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields
|
||||
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
|
||||
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
|
||||
|
||||
The insert_slice_async operation supports the following arguments:
|
||||
|
||||
* src: the tensor that is inserted.
|
||||
* dst: the tensor into which the `$src` tensor is inserted.
|
||||
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
|
||||
* mask: optional tensor-rank number of boolean masks which specify which
|
||||
elements of the `$src` tensor are inserted into the `$dst` tensor.
|
||||
* other: optional tensor-rank number of other tensors which specify what
|
||||
values are inserted into the `$dst` tensor if the corresponding
|
||||
element of the `$mask` tensor is false.
|
||||
|
||||
In the future, we may decompose this operation into a sequence of:
|
||||
|
||||
* `async` operation to specify a sequence of asynchronous operations
|
||||
* `load` operation to load a tensor from global memory
|
||||
* `insert_slice` operations to insert the `$src` tensor into the `$dst` tensor
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%1 = triton_gpu_rocm.alloc_tensor : tensor<2x32xf32>
|
||||
%2 = triton_gpu_rocm.insert_slice_async %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
|
||||
triton_gpu_rocm.async_wait { num = 0 : i32 }
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"Value":$mask, "Value":$other,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
//let assemblyFormat = [{
|
||||
// $src `,` $dst ``
|
||||
// $index, $mask, $other
|
||||
// attr-dict `:` type($src) `->` type($dst)
|
||||
//}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
|
||||
DenseSet<unsigned> validLoadBytes;
|
||||
if (computeCapability >= 80) {
|
||||
validLoadBytes = {4, 8, 16};
|
||||
}
|
||||
return validLoadBytes;
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
|
||||
ResultsAreSharedEncodingROCM]> {
|
||||
let summary = "allocate tensor";
|
||||
|
||||
let description = [{
|
||||
This operation defines a tensor of a particular shape.
|
||||
The contents of the tensor are supposed to be in shared memory.
|
||||
|
||||
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{attr-dict `:` type($result)}];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
26
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUTypes.td
vendored
Normal file
26
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/TritonGPUTypes.td
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_TYPES
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_TYPES
|
||||
|
||||
include "triton/Dialect/TritonGPUROCM/IR/TritonGPUDialect.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
|
||||
class TTG_TypeDef<string name, string _mnemonic, list<Trait> traits = []>
|
||||
: TypeDef<TritonGPUROCM_Dialect, name, traits> {
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
def TTG_TokenType : TTG_TypeDef<"Token", "token"> {
|
||||
let parameters = (ins "int32_t":$type);
|
||||
|
||||
let builders = [
|
||||
TypeBuilder<(ins "unsigned":$type), [{
|
||||
return $_get($_ctxt, type);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
10
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Types.h
vendored
Normal file
10
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/IR/Types.h
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
#ifndef TRITONGPUROCM_IR_TYPES_H_
|
||||
#define TRITONGPUROCM_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Types.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
||||
3
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/CMakeLists.txt
vendored
Normal file
3
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPUROCM)
|
||||
add_public_tablegen_target(TritonGPUROCMTransformsIncGen)
|
||||
42
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/Passes.h
vendored
Normal file
42
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/Passes.h
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMPipelinePass(int numStages = 3,
|
||||
int numWarps = 4,
|
||||
int numCTAs = 1,
|
||||
int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMStreamPipelinePass();
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonGPUROCMAccelerateMatmulPass(int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMPrefetchPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMCanonicalizeLoopsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMCoalescePass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMReorderInstructionsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMDecomposeConversionsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMRemoveLayoutConversionsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMVerifier();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMOptimizeDotOperandsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUROCMOptimizeEpiloguePass();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonGPUROCM/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
#endif
|
||||
166
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/Passes.td
vendored
Normal file
166
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/Passes.td
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_PASSES
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
let summary = "pipeline";
|
||||
|
||||
let description = [{
|
||||
Replace `LoadOp` in loops by `InsertSliceAsyncOp` instructions that asynchronously construct the data
|
||||
needed at the next iteration
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numStages", "num-stages",
|
||||
"int32_t", /*default*/"3",
|
||||
"number of pipeline stages">,
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps per block">,
|
||||
Option<"numCTAs", "num-ctas",
|
||||
"int32_t", /*default*/"1",
|
||||
"number of CTAs per CGA">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUStreamPipeline : Pass<"tritongpu-stream-pipeline", "mlir::ModuleOp"> {
|
||||
let summary = "pipeline";
|
||||
|
||||
let description = [{
|
||||
Pipeline global loads through registers to shared memory while computing on previous
|
||||
tile
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMStreamPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
let summary = "prefetch";
|
||||
|
||||
let description = [{
|
||||
Decompose `DotOp` instructions in loops into several finer-grained `DotOp`
|
||||
that may have their operands constructed at the end of the previous iteration
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMPrefetchPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
|
||||
let summary = "accelerate matmul";
|
||||
|
||||
let description = [{
|
||||
Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators
|
||||
(e.g., Nvidia tensor cores)
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMAccelerateMatmulPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {
|
||||
let summary = "fuse transpositions";
|
||||
|
||||
let description = [{
|
||||
Re-arranged layouts of tensors used as matrix multiplication operands so as to promote the use of
|
||||
hardware-accelerated transpositions.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMOptimizeDotOperandsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
|
||||
let summary = "coalesce";
|
||||
|
||||
let description = [{
|
||||
TODO
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMCoalescePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect"];
|
||||
}
|
||||
|
||||
|
||||
def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> {
|
||||
let summary = "remove superfluous layout conversions";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMRemoveLayoutConversionsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
}
|
||||
|
||||
def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::ModuleOp"> {
|
||||
let summary = "Optimize epilogue: (1) Store accumulators directly without going thorough SMEM in epilogue.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMOptimizeEpiloguePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
}
|
||||
|
||||
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {
|
||||
let summary = "Reorder instructions";
|
||||
|
||||
let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving "
|
||||
"conversions from shared memory before their first use) and (2) promote LLVM instruction "
|
||||
"order more friendly to `ptxas`.";
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMReorderInstructionsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir::ModuleOp"> {
|
||||
let summary = "Decompose convert[distributed -> dotOperand] into convert[distributed -> shared -> dotOperand]";
|
||||
|
||||
let description = "Decomposing conversions this way makes it possible to use CSE and re-use #shared tensors";
|
||||
|
||||
let constructor = "mlir::createTritonGPUROCMDecomposeConversionsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu_rocm::TritonGPUROCMDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,38 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines utilities to use while converting to the TritonGPU dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class TritonGPUTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp,
|
||||
int numCTAs);
|
||||
int getNumWarps() const { return numWarps; }
|
||||
int getThreadsPerWarp() const { return threadsPerWarp; }
|
||||
int getNumCTAs() const { return numCTAs; }
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
int numWarps;
|
||||
int threadsPerWarp;
|
||||
int numCTAs;
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
|
||||
public:
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx,
|
||||
TritonGPUTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||
153
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/Utility.h
vendored
Normal file
153
python/triton/third_party/hip/include/triton/Dialect/TritonGPUROCM/Transforms/Utility.h
vendored
Normal file
@@ -0,0 +1,153 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_UTILITY_H_
|
||||
#define TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_UTILITY_H_
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace triton {
|
||||
class LoadOp;
|
||||
class StoreOp;
|
||||
class FuncOp;
|
||||
namespace gpu_rocm {
|
||||
class SharedEncodingAttr;
|
||||
}
|
||||
} // namespace triton
|
||||
|
||||
SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
|
||||
const ArrayRef<int64_t> &shape,
|
||||
RankedTensorType type);
|
||||
|
||||
/// Returns true if the Load is for TMA
|
||||
bool isLoadFromTensorPtr(triton::LoadOp op);
|
||||
|
||||
/// Returns true if the store is for TMA
|
||||
bool isStoreToTensorPtr(triton::StoreOp op);
|
||||
|
||||
/// Return the first consumer of v
|
||||
Operation *getFirstUser(Value v);
|
||||
|
||||
/// Return the proper SharedEncodingAttr according to shape/order
|
||||
triton::gpu_rocm::SharedEncodingAttr getSharedEncoding(RankedTensorType tensorTy);
|
||||
|
||||
/* Dump Triton IR in graphviz dot format.
|
||||
*
|
||||
* You can override `onValue` and `onOperation` in a subclass to mark
|
||||
* specific Values and Operations. The below subclass
|
||||
* GraphLayoutMarker is an example.
|
||||
*
|
||||
* Default NodeInfo for Value nodes:
|
||||
* {{"shape": "box"},
|
||||
* {"style", "filled"},
|
||||
* {"fillcolor", "white"},
|
||||
* {"label", shapeStr}}
|
||||
*
|
||||
* Default NodeInfo for Operation nodes:
|
||||
* {{"shape": "ellipse"},
|
||||
* {"style", "filled"},
|
||||
* {"fillcolor", "white"},
|
||||
* {"label", operationName}}
|
||||
*
|
||||
* If the key "label" is not set by `onValue` or `onOperation`, default labels
|
||||
* will be generated. For Value node, the default label is the shape string and
|
||||
* for Operation node, it is the operation name.
|
||||
*
|
||||
* Reference:
|
||||
* https://graphviz.org/doc/info/shapes.html
|
||||
* https://graphviz.org/doc/info/colors.html
|
||||
*
|
||||
* Usage:
|
||||
* C++: GraphDumper().dumpToFile(func, "func.dot");
|
||||
* Shell: dot -Tjpg func.dot -o func.jpg
|
||||
*/
|
||||
class GraphDumper {
|
||||
public:
|
||||
using NodeInfo = std::map<std::string, std::string>;
|
||||
|
||||
// Override this function to mark specific Values
|
||||
virtual NodeInfo onValue(Value value) const;
|
||||
// Override this function to mark specific Operations
|
||||
virtual NodeInfo onOperation(Operation *op) const;
|
||||
|
||||
std::string dump(triton::FuncOp func) const;
|
||||
void dumpToFile(triton::FuncOp func, const std::string &filename) const;
|
||||
|
||||
protected:
|
||||
std::string getShapeStr(const Type &type) const;
|
||||
|
||||
std::string getUniqueId(Value value) const;
|
||||
std::string getUniqueId(Operation *op) const;
|
||||
|
||||
std::string emitNode(const std::string &id, const NodeInfo style) const;
|
||||
std::string emitEdge(const std::string &srcId,
|
||||
const std::string &destId) const;
|
||||
|
||||
std::string emitValueNode(Value value) const;
|
||||
std::string emitOperationNode(Operation *op) const;
|
||||
};
|
||||
|
||||
/* A subclass of GraphDumper that marks different layout kinds in different
|
||||
* colors.*/
|
||||
class GraphLayoutMarker : public GraphDumper {
|
||||
public:
|
||||
NodeInfo onValue(Value value) const override;
|
||||
|
||||
protected:
|
||||
std::string getColor(const Type &type) const;
|
||||
};
|
||||
|
||||
// Infers the encoding of the result of op given the source encoding.
|
||||
std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding);
|
||||
|
||||
// Infers the encoding of the source of op given the result encoding.
|
||||
std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding);
|
||||
|
||||
bool isExpensiveLoadOrStore(Operation *op);
|
||||
|
||||
bool canFoldIntoConversion(Operation *op, Attribute targetEncoding);
|
||||
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping);
|
||||
|
||||
// Get backward slice of tensor values starting from the root node along with
|
||||
// encoding propagation.
|
||||
LogicalResult getConvertBackwardSlice(
|
||||
Value root, SetVector<Value> &slice, Attribute rootEncoding,
|
||||
DenseMap<Value, Attribute> &layout,
|
||||
std::function<bool(Operation *)> stopPropagation = nullptr);
|
||||
|
||||
// Populate pattern to remove dead cycles in ForOp.
|
||||
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order);
|
||||
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, unsigned linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
|
||||
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
// Returns null if the op is not inside a agent region (warp specialization
|
||||
// mode). Note that there should be at most one agent id attached to the
|
||||
// operation.
|
||||
std::optional<int> getWSAgentId(Operation *op);
|
||||
std::optional<int> getWSRoleId(Operation *op);
|
||||
void setRoleId(Operation *op, int roleId);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPUROCM_TRANSFORMS_UTILITY_H_
|
||||
54
python/triton/third_party/hip/include/triton/Target/HSACO/HSACOTranslation.h
vendored
Normal file
54
python/triton/third_party/hip/include/triton/Target/HSACO/HSACOTranslation.h
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
#ifndef TRITON_TARGET_HSACOTRANSLATION_H
|
||||
#define TRITON_TARGET_HSACOTRANSLATION_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
}
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// add external libs to modules
|
||||
void addExternalLibs(mlir::ModuleOp &module,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<std::string> &paths);
|
||||
|
||||
// Translate Triton dialect to TritonGPU, return null if failed.
|
||||
void translateTritonToTritonGPUROCM(mlir::ModuleOp &module, int computeCapability,
|
||||
int numWarps, int numStages);
|
||||
|
||||
// Translate Triton GPU to mlir LLVM dialect, return null if failed.
|
||||
void translateTritonGPUROCMToLLVMDialect(mlir::ModuleOp &module,
|
||||
int computeCapability, bool isROCM);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMDialectToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, bool isROCM);
|
||||
|
||||
// Translate LLVM IR to HSACO code.
|
||||
std::tuple<std::string, std::string>
|
||||
translateLLVMIRToHSACO(llvm::Module &module, std::string gfx_arch,
|
||||
std::string gfx_triple, std::string gfx_features);
|
||||
|
||||
std::tuple<std::string, std::string>
|
||||
translateTritonIRToHSACO(mlir::ModuleOp module, std::string gfx_arch,
|
||||
std::string gfx_triple, std::string gfx_features,
|
||||
int numWarps, int numStages,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<std::string> &paths);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
71
python/triton/third_party/hip/lib/AnalysisROCM/Alias.cpp
vendored
Normal file
71
python/triton/third_party/hip/lib/AnalysisROCM/Alias.cpp
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
#include "triton/AnalysisROCM/Alias.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "triton/AnalysisROCM/Utility.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
AliasInfo ret;
|
||||
for (auto value : lhs.allocs) {
|
||||
ret.insert(value);
|
||||
}
|
||||
for (auto value : rhs.allocs) {
|
||||
ret.insert(value);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void SharedMemoryAliasAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
|
||||
ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
|
||||
AliasInfo aliasInfo;
|
||||
bool pessimistic = true;
|
||||
if (maybeSharedAllocationOp(op)) {
|
||||
// These ops may allocate a new shared memory buffer.
|
||||
auto result = op->getResult(0);
|
||||
// XXX(Keren): the following ops are always aliasing for now
|
||||
if (isa<triton::gpu_rocm::ExtractSliceOp, triton::TransOp,
|
||||
triton::nvidia_gpu::ExtractMBarrierOp>(op)) {
|
||||
// extract_slice %src
|
||||
// trans %src
|
||||
aliasInfo = AliasInfo(operands[0]->getValue());
|
||||
pessimistic = false;
|
||||
} else if (isa<tensor::InsertSliceOp, triton::gpu_rocm::InsertSliceAsyncOp,
|
||||
triton::nvidia_gpu::InsertSliceAsyncV2Op>(op)) {
|
||||
// insert_slice_async %src, %dst, %index
|
||||
// insert_slice %src into %dst[%offsets]
|
||||
aliasInfo = AliasInfo(operands[1]->getValue());
|
||||
pessimistic = false;
|
||||
} else if (isa<triton::nvidia_gpu::StoreAsyncOp>(op)) {
|
||||
aliasInfo = AliasInfo(operands[0]->getValue());
|
||||
pessimistic = false;
|
||||
} else if (triton::gpu_rocm::isSharedEncoding(result)) {
|
||||
aliasInfo.insert(result);
|
||||
pessimistic = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (pessimistic) {
|
||||
return setAllToEntryStates(results);
|
||||
}
|
||||
// Join all lattice elements
|
||||
for (auto *result : results)
|
||||
propagateIfChanged(result, result->join(aliasInfo));
|
||||
}
|
||||
|
||||
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
|
||||
// TODO: implement
|
||||
return AliasResult::MayAlias;
|
||||
}
|
||||
|
||||
ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op,
|
||||
Value location) {
|
||||
// TODO: implement
|
||||
return ModRefResult::getModAndRef();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
721
python/triton/third_party/hip/lib/AnalysisROCM/Allocation.cpp
vendored
Normal file
721
python/triton/third_party/hip/lib/AnalysisROCM/Allocation.cpp
vendored
Normal file
@@ -0,0 +1,721 @@
|
||||
#include "triton/AnalysisROCM/Allocation.h"
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "triton/AnalysisROCM/Alias.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
|
||||
using ::mlir::triton::gpu_rocm::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getContigPerThread;
|
||||
using ::mlir::triton::gpu_rocm::getOrder;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTATile;
|
||||
using ::mlir::triton::gpu_rocm::getSizePerThread;
|
||||
using ::mlir::triton::gpu_rocm::MfmaEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::SliceEncodingAttr;
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Allocation Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
namespace triton {
|
||||
|
||||
// Bitwidth of pointers
|
||||
constexpr int kPtrBitWidth = 64;
|
||||
|
||||
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
|
||||
getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
|
||||
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
||||
auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
|
||||
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
||||
auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
|
||||
assert(!(srcMmaLayout && dstMmaLayout) &&
|
||||
"Unexpected mma -> mma layout conversion");
|
||||
// mma or dot layout does not have an order, so the order depends on the
|
||||
// layout of the other operand.
|
||||
auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
|
||||
: getOrder(srcLayout);
|
||||
auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
|
||||
: getOrder(dstLayout);
|
||||
|
||||
return {inOrd, outOrd};
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu_rocm::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec) {
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
Attribute dstLayout = dstTy.getEncoding();
|
||||
|
||||
if (shouldUseDistSmem(srcLayout, dstLayout)) {
|
||||
// TODO: padding to avoid bank conflicts
|
||||
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
|
||||
}
|
||||
|
||||
// MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem
|
||||
if (auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
if (isMmaToDotShortcut(srcTy, dstTy)) {
|
||||
return {};
|
||||
}
|
||||
} else if (auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (isMmaToMmaShortcut(srcTy, dstTy)) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
if (srcLayout.isa<MfmaEncodingAttr>() &&
|
||||
srcLayout.dyn_cast<MfmaEncodingAttr>().getIsTransposed() &&
|
||||
dstLayout.isa<DotOperandEncodingAttr>())
|
||||
if (isMfmaToDotShortcut(srcTy, dstTy))
|
||||
return {};
|
||||
#endif
|
||||
|
||||
assert(srcLayout && dstLayout &&
|
||||
"Unexpected layout in getScratchConfigForCvtLayout()");
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
|
||||
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
|
||||
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
auto srcShapePerCTA = getShapePerCTA(srcTy);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstTy);
|
||||
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
|
||||
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
|
||||
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] =
|
||||
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
|
||||
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
|
||||
}
|
||||
if (rank == 1)
|
||||
return paddedRepShape;
|
||||
unsigned paddedDim = 1;
|
||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
paddedDim = dstBlockedLayout.getOrder()[0];
|
||||
}
|
||||
paddedRepShape[paddedDim] += pad;
|
||||
return paddedRepShape;
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForStoreAsync(triton::nvidia_gpu::StoreAsyncOp op) {
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
|
||||
}
|
||||
|
||||
// TODO: extend beyond scalars
|
||||
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
|
||||
SmallVector<unsigned> smemShape;
|
||||
if (op.getPtr().getType().isa<RankedTensorType>()) {
|
||||
// do nothing or just assert because shared memory is not used in tensor up
|
||||
// to now
|
||||
} else {
|
||||
// need only bytes for scalar
|
||||
// always vec = 1 and elemsPerThread = 1 for scalar?
|
||||
smemShape.push_back(1);
|
||||
}
|
||||
return smemShape;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
|
||||
return SmallVector<unsigned>{1};
|
||||
}
|
||||
|
||||
class AllocationAnalysis {
|
||||
public:
|
||||
AllocationAnalysis(Operation *operation,
|
||||
Allocation::FuncAllocMapT *funcAllocMap,
|
||||
Allocation *allocation)
|
||||
: operation(operation), funcAllocMap(funcAllocMap),
|
||||
allocation(allocation) {
|
||||
run();
|
||||
}
|
||||
|
||||
private:
|
||||
using BufferT = Allocation::BufferT;
|
||||
|
||||
/// Value -> Liveness Range
|
||||
using IntervalT = Interval<size_t>;
|
||||
/// Use MapVector to ensure determinism.
|
||||
using BufferRangeMapT = llvm::MapVector<BufferT *, IntervalT>;
|
||||
/// Nodes -> Nodes
|
||||
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
|
||||
|
||||
/// Set of Liveness Intervals
|
||||
class LivenessR : public SmallVector<IntervalT, 4> {
|
||||
public:
|
||||
LivenessR() = default;
|
||||
LivenessR(const LivenessR &) = default;
|
||||
|
||||
/// Disjointness
|
||||
bool isDisjoint() const {
|
||||
if (size() < 2)
|
||||
return false;
|
||||
// sorted so the first OOB proves disjoint
|
||||
auto maxId = (*this)[0].end();
|
||||
for (auto rng : *this) {
|
||||
if (rng.start() <= maxId) {
|
||||
// adjoining
|
||||
maxId = std::max(maxId, rng.end());
|
||||
} else
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void sort() {
|
||||
llvm::sort(*this, [](const auto &lhs, const auto &rhs) {
|
||||
return lhs.start() <= rhs.start();
|
||||
});
|
||||
}
|
||||
|
||||
bool addAdjacent(size_t id) {
|
||||
bool isAdjacent = false;
|
||||
for (auto &interval : *this) {
|
||||
if (interval.adjacent(id)) {
|
||||
isAdjacent = true;
|
||||
interval = interval.merge(IntervalT(id));
|
||||
}
|
||||
}
|
||||
return isAdjacent;
|
||||
}
|
||||
|
||||
void add(size_t id) {
|
||||
if (!addAdjacent(id))
|
||||
push_back(IntervalT(id));
|
||||
}
|
||||
IntervalT unionize() const {
|
||||
IntervalT res;
|
||||
if (size()) {
|
||||
res = front();
|
||||
for (auto &I : *this)
|
||||
res = res.merge(I);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
typedef function_ref<LivenessR(Value value)> LivenessF;
|
||||
|
||||
void run() {
|
||||
getValuesAndSizes();
|
||||
resolveLiveness();
|
||||
computeOffsets();
|
||||
}
|
||||
|
||||
/// Initializes explicitly defined shared memory values for a given operation.
|
||||
void getExplicitValueSize(Operation *op) {
|
||||
// Values returned from scf.yield will not be allocated even though they
|
||||
// have the shared encoding.
|
||||
// For example: %a = scf.if -> yield
|
||||
// %a must be allocated elsewhere by other operations.
|
||||
// FIXME(Keren): extract and insert are always alias for now
|
||||
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op))
|
||||
return;
|
||||
|
||||
// XXX(Keren): Why this hard-coded alignment?
|
||||
size_t kAlignment = 8;
|
||||
for (Value result : op->getResults()) {
|
||||
if (triton::gpu_rocm::isSharedEncoding(result)) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
|
||||
auto shapePerCTA = triton::gpu_rocm::getShapePerCTA(tensorType);
|
||||
auto bytes = product<int64_t>(shapePerCTA) *
|
||||
tensorType.getElementTypeBitWidth() / 8;
|
||||
|
||||
// XXX(Keren): magic numbers 256 and 1024
|
||||
// benzh@maybe alignment should be passed in.
|
||||
// Software swizzling calculates phase based on offset, while hardware
|
||||
// swizzling do that based on physical address. Thus only by setting the
|
||||
// alignment to 1024 can ensure the correctness.
|
||||
if (bytes > 256)
|
||||
kAlignment = 1024;
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes,
|
||||
kAlignment);
|
||||
}
|
||||
}
|
||||
if (isa<triton::nvidia_gpu::AllocMBarrierOp>(op)) {
|
||||
Value result = op->getResult(0);
|
||||
if (!result.getType().isa<RankedTensorType>())
|
||||
// In case AllocMBarrierOp is allocating scalar mbarriers
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, 8,
|
||||
kAlignment);
|
||||
}
|
||||
}
|
||||
|
||||
template <BufferT::BufferKind T>
|
||||
void maybeAddScratchBuffer(Operation *op, unsigned bytes,
|
||||
unsigned alignment) {
|
||||
if (bytes > 0)
|
||||
allocation->addBuffer<T>(op, bytes, alignment);
|
||||
}
|
||||
|
||||
template <BufferT::BufferKind T>
|
||||
void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
|
||||
if (bytes > 0)
|
||||
allocation->addBuffer<T>(op, bytes);
|
||||
}
|
||||
|
||||
/// Initializes temporary shared memory for a given operation.
|
||||
void getScratchValueSize(Operation *op) {
|
||||
const size_t scratchAlignment = 128;
|
||||
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
|
||||
ReduceOpHelper helper(reduceOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
|
||||
ScanLoweringHelper helper(scanOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto cvtLayout = dyn_cast<triton::gpu_rocm::ConvertLayoutOp>(op)) {
|
||||
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
|
||||
auto srcEncoding = srcTy.getEncoding();
|
||||
auto dstEncoding = dstTy.getEncoding();
|
||||
if (srcEncoding.isa<SharedEncodingAttr>() ||
|
||||
dstEncoding.isa<SharedEncodingAttr>()) {
|
||||
// Conversions from/to shared memory do not need scratch memory.
|
||||
return;
|
||||
}
|
||||
// ConvertLayoutOp with both input/output non-shared_layout
|
||||
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
|
||||
// also possible to realize it with other approaches in restricted
|
||||
// conditions, such as warp-shuffle
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto bytes =
|
||||
srcTy.getElementType().isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto storeAsyncOp =
|
||||
dyn_cast<triton::nvidia_gpu::StoreAsyncOp>(op)) {
|
||||
auto srcTy = storeAsyncOp.getSrc().getType().cast<RankedTensorType>();
|
||||
auto srcEncoding = srcTy.getEncoding();
|
||||
if (!srcEncoding.isa<MmaEncodingAttr>()) {
|
||||
return;
|
||||
}
|
||||
auto smemShape = getScratchConfigForStoreAsync(storeAsyncOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto bytes = elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes, 1024);
|
||||
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
|
||||
auto value = op->getOperand(0);
|
||||
// only scalar requires scratch memory
|
||||
// make it explicit for readability
|
||||
if (value.getType().dyn_cast<RankedTensorType>()) {
|
||||
// nothing to do
|
||||
} else {
|
||||
auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto elemTy =
|
||||
value.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto bytes =
|
||||
elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
}
|
||||
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
|
||||
auto value = op->getOperand(0);
|
||||
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto elemTy =
|
||||
value.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
|
||||
auto callable = callOp.resolveCallable();
|
||||
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
|
||||
auto *funcAlloc = &(*funcAllocMap)[funcOp];
|
||||
auto bytes = funcAlloc->getSharedMemorySize();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
|
||||
scratchAlignment);
|
||||
}
|
||||
}
|
||||
|
||||
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
|
||||
dataflow::Lattice<AliasInfo> *latticeElement =
|
||||
analysis.getLatticeElement(value);
|
||||
if (latticeElement) {
|
||||
AliasInfo &info = latticeElement->getValue();
|
||||
if (!info.getAllocs().empty()) {
|
||||
for (auto alloc : info.getAllocs()) {
|
||||
allocation->addAlias(value, alloc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract all shared memory values and their sizes
|
||||
void getValuesAndSizes() {
|
||||
// Get the alloc values
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
getExplicitValueSize(op);
|
||||
getScratchValueSize(op);
|
||||
});
|
||||
// Get the alias values
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
SharedMemoryAliasAnalysis *aliasAnalysis =
|
||||
solver->load<SharedMemoryAliasAnalysis>();
|
||||
if (failed(solver->initializeAndRun(operation))) {
|
||||
// TODO: return error instead of bailing out..
|
||||
llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
|
||||
}
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
getValueAlias(operand, *aliasAnalysis);
|
||||
}
|
||||
for (auto value : op->getResults()) {
|
||||
getValueAlias(value, *aliasAnalysis);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Computes the liveness range of the allocated value.
|
||||
/// Each buffer is allocated only once.
|
||||
void resolveExplicitBufferLiveness(LivenessF getLiveness) {
|
||||
for (auto valueBufferIter : allocation->valueBuffer) {
|
||||
auto value = valueBufferIter.first;
|
||||
auto *buffer = valueBufferIter.second;
|
||||
auto ranges = getLiveness(value);
|
||||
bufferRange[buffer] = ranges.unionize();
|
||||
}
|
||||
}
|
||||
|
||||
/// Extends the liveness range by unionizing the liveness range of the aliased
|
||||
/// values because each allocated buffer could be an alias of others, if block
|
||||
/// arguments are involved.
|
||||
/// Only unionize adjacent live ranges to account for loop-carried buffers that
|
||||
/// are mutually exclusive.
|
||||
/// Example from stream pipeliner:
|
||||
/// 3 %b0 = convert_layout %g0 -+
|
||||
/// 4 %fr = for (.., %arg0 = %b0) { |
|
||||
/// 5 %gn = load %pc |
|
||||
/// 6 %bc = convert_layout %arg0 -+
|
||||
/// 7 %v = add %bc, ...
|
||||
/// 8 %bn = convert_layout %gn -+
|
||||
/// 9 %pn = addptr %pc, %cst |
|
||||
/// 10 } |
|
||||
/// 11 %be = convert_layout %fr#1 -+
|
||||
/// 12 %ve = add %be
|
||||
void resolveAliasBufferLiveness(LivenessF getLiveness) {
|
||||
for (auto aliasBufferIter : allocation->aliasBuffer) {
|
||||
auto value = aliasBufferIter.first;
|
||||
auto buffers = aliasBufferIter.second;
|
||||
auto aranges = getLiveness(value);
|
||||
bool disjoint = aranges.isDisjoint();
|
||||
for (auto *buffer : buffers) {
|
||||
auto range = aranges[0];
|
||||
if (bufferRange.count(buffer)) {
|
||||
auto brange = bufferRange[buffer];
|
||||
if (disjoint) {
|
||||
// find adjacent/intersecting
|
||||
for (auto arange : aranges) {
|
||||
if (arange.adjacent(brange) ||
|
||||
arange.intersects(brange))
|
||||
brange = arange.merge(brange);
|
||||
}
|
||||
range = brange;
|
||||
} else {
|
||||
// Extend the allocated buffer's range
|
||||
range = range.merge(brange);
|
||||
}
|
||||
}
|
||||
bufferRange[buffer] = range;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the liveness range of scratched buffers.
|
||||
/// Some operations may have a temporary buffer that is not explicitly
|
||||
/// allocated, but is used to store intermediate results.
|
||||
void resolveScratchBufferLiveness(
|
||||
const DenseMap<Operation *, size_t> &operationId) {
|
||||
// Analyze liveness of scratch buffers and vritual buffers.
|
||||
auto processScratchMemory = [&](const auto &container) {
|
||||
for (auto opScratchIter : container) {
|
||||
// Any scratch memory's live range is the current operation's live
|
||||
// range.
|
||||
auto *op = opScratchIter.first;
|
||||
auto *buffer = opScratchIter.second;
|
||||
bufferRange.insert({buffer, Interval(operationId.lookup(op),
|
||||
operationId.lookup(op) + 1)});
|
||||
}
|
||||
};
|
||||
processScratchMemory(allocation->opScratch);
|
||||
processScratchMemory(allocation->opVirtual);
|
||||
}
|
||||
|
||||
/// Resolves liveness of all values involved under the root operation.
|
||||
void resolveLiveness() {
|
||||
// Assign an ID to each operation using post-order traversal.
|
||||
// To achieve the correct liveness range, the parent operation's ID
|
||||
// should be greater than each of its child operation's ID .
|
||||
// Example:
|
||||
// ...
|
||||
// %5 = triton.convert_layout %4
|
||||
// %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) {
|
||||
// %2 = triton.convert_layout %5
|
||||
// ...
|
||||
// scf.yield %arg0
|
||||
// }
|
||||
// For example, %5 is defined in the parent region and used in
|
||||
// the child region, and is not passed as a block argument.
|
||||
// %6 should should have an ID greater than its child operations,
|
||||
// otherwise %5 liveness range ends before the child operation's liveness
|
||||
// range ends.
|
||||
DenseMap<Operation *, size_t> operationId;
|
||||
operation->walk<WalkOrder::PostOrder>(
|
||||
[&](Operation *op) { operationId[op] = operationId.size(); });
|
||||
|
||||
// Analyze liveness of explicit buffers
|
||||
Liveness liveness(operation);
|
||||
auto getValueLivenessRange = [&](Value value) {
|
||||
LivenessR ranges;
|
||||
// Shared memory allocated by mbarrier cannot be reused
|
||||
if (value.getDefiningOp() &&
|
||||
isa<triton::nvidia_gpu::AllocMBarrierOp>(value.getDefiningOp())) {
|
||||
ranges.push_back(Interval(std::numeric_limits<size_t>::min(),
|
||||
std::numeric_limits<size_t>::max()));
|
||||
return ranges;
|
||||
}
|
||||
|
||||
auto liveOperations = liveness.resolveLiveness(value);
|
||||
std::for_each(liveOperations.begin(), liveOperations.end(),
|
||||
[&](Operation *liveOp) {
|
||||
ranges.add(operationId[liveOp]);
|
||||
});
|
||||
ranges.sort();
|
||||
return ranges;
|
||||
};
|
||||
|
||||
resolveExplicitBufferLiveness(getValueLivenessRange);
|
||||
resolveAliasBufferLiveness(getValueLivenessRange);
|
||||
resolveScratchBufferLiveness(operationId);
|
||||
}
|
||||
|
||||
/// Computes the shared memory offsets for all related values.
|
||||
/// Paper: Algorithms for Compile-Time Memory Optimization
|
||||
/// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf)
|
||||
void computeOffsets() {
|
||||
SmallVector<BufferT *> buffers;
|
||||
for (auto bufferIter : bufferRange) {
|
||||
buffers.emplace_back(bufferIter.first);
|
||||
}
|
||||
|
||||
DenseMap<BufferT *, size_t> bufferStart;
|
||||
calculateStarts(buffers, bufferStart);
|
||||
|
||||
// NOTE: The original paper doesn't consider interference between
|
||||
// the bumped ranges. Buffers that previously do not interfere with
|
||||
// could interfere after offset bumping if their liveness ranges overlap.
|
||||
// Therefore, we rerun the interference graph algorithm after bumping so
|
||||
// that we regroup the buffers and color them again. Since we always
|
||||
// increase the buffer offset and keep reducing conflicts, we will
|
||||
// eventually reach a fixed point.
|
||||
GraphT interference;
|
||||
buildInterferenceGraph(buffers, bufferStart, interference);
|
||||
do {
|
||||
allocate(buffers, interference, bufferStart);
|
||||
buildInterferenceGraph(buffers, bufferStart, interference);
|
||||
} while (!interference.empty());
|
||||
}
|
||||
|
||||
/// Computes the initial shared memory offsets.
|
||||
void calculateStarts(const SmallVector<BufferT *> &buffers,
|
||||
DenseMap<BufferT *, size_t> &bufferStart) {
|
||||
// v = values in shared memory
|
||||
// t = triplet of (size, start, end)
|
||||
// shared memory space
|
||||
// -
|
||||
// | *******t4
|
||||
// | /|\ v2 inserts t4, t5, and t6
|
||||
// | |
|
||||
// | ******t5 ************t6
|
||||
// | ^^^^^v2^^^^^^
|
||||
// | | *********************t2
|
||||
// | \|/ v2 erases t1
|
||||
// | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3
|
||||
// |---------------------------------------------| liveness range
|
||||
// 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
|
||||
// If the available triple's range is less than a given buffer range,
|
||||
// we won't know if there has been an overlap without using graph coloring.
|
||||
// Start -> Liveness Range
|
||||
using TripleMapT = std::multimap<size_t, IntervalT>;
|
||||
TripleMapT tripleMap;
|
||||
tripleMap.insert(std::make_pair(0, IntervalT()));
|
||||
SmallVector<BufferT *> xBuffers = buffers;
|
||||
while (!xBuffers.empty()) {
|
||||
auto tripleIt = tripleMap.begin();
|
||||
auto size = tripleIt->first;
|
||||
auto range = tripleIt->second;
|
||||
tripleMap.erase(tripleIt);
|
||||
auto bufferIt =
|
||||
std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
|
||||
auto xRange = bufferRange[buffer];
|
||||
bool res = xRange.intersects(range);
|
||||
for (auto val : tripleMap)
|
||||
res = res &&
|
||||
!val.second.intersects(xRange); // only one buffer intersect
|
||||
return res;
|
||||
});
|
||||
if (bufferIt != xBuffers.end()) {
|
||||
auto buffer = *bufferIt;
|
||||
auto xSize = buffer->size;
|
||||
auto xRange = bufferRange.lookup(buffer);
|
||||
// TODO(Keren): A buffer's size shouldn't be determined here, have to
|
||||
// clean it up
|
||||
size_t alignment = buffer->alignment;
|
||||
size_t alignSize = ((size + alignment - 1) / alignment) * alignment;
|
||||
bufferStart[buffer] = alignSize;
|
||||
tripleMap.insert({alignSize + xSize,
|
||||
Interval{std::max(range.start(), xRange.start()),
|
||||
std::min(range.end(), xRange.end())}});
|
||||
// We could either insert (range.start, xRange.start) or (range.start,
|
||||
// xRange.end), both are correct and determine the potential buffer
|
||||
// offset, and the graph coloring algorithm will solve the interference,
|
||||
// if any
|
||||
if (range.start() < xRange.start())
|
||||
tripleMap.insert({size, Interval{range.start(), xRange.end()}});
|
||||
if (xRange.end() < range.end())
|
||||
tripleMap.insert({size, Interval{xRange.start(), range.end()}});
|
||||
xBuffers.erase(bufferIt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a graph of all shared memory values. Edges are created between
|
||||
/// shared memory values that are overlapping.
|
||||
void buildInterferenceGraph(const SmallVector<BufferT *> &buffers,
|
||||
const DenseMap<BufferT *, size_t> &bufferStart,
|
||||
GraphT &interference) {
|
||||
// Reset interference graph
|
||||
interference.clear();
|
||||
for (auto x : buffers) {
|
||||
for (auto y : buffers) {
|
||||
if (x == y)
|
||||
continue;
|
||||
auto xStart = bufferStart.lookup(x);
|
||||
auto yStart = bufferStart.lookup(y);
|
||||
auto xSize = x->size;
|
||||
auto ySize = y->size;
|
||||
Interval xSizeRange = {xStart, xStart + xSize};
|
||||
Interval ySizeRange = {yStart, yStart + ySize};
|
||||
auto xOpRange = bufferRange.lookup(x);
|
||||
auto yOpRange = bufferRange.lookup(y);
|
||||
if (xOpRange.intersects(yOpRange) &&
|
||||
xSizeRange.intersects(ySizeRange)) {
|
||||
interference[x].insert(y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalizes shared memory offsets considering interference.
|
||||
void allocate(const SmallVector<BufferT *> &buffers,
|
||||
const GraphT &interference,
|
||||
DenseMap<BufferT *, size_t> &bufferStart) {
|
||||
// Reset shared memory size
|
||||
allocation->sharedMemorySize = 0;
|
||||
// First-fit graph coloring
|
||||
// Neighbors are nodes that interfere with each other.
|
||||
// We color a node by finding the index of the first available
|
||||
// non-neighboring node or the first neighboring node without any color.
|
||||
// Nodes with the same color do not interfere with each other.
|
||||
DenseMap<BufferT *, int> colors;
|
||||
for (auto value : buffers) {
|
||||
colors[value] = (value == buffers[0]) ? 0 : -1;
|
||||
}
|
||||
SmallVector<bool> available(buffers.size());
|
||||
for (auto x : buffers) {
|
||||
std::fill(available.begin(), available.end(), true);
|
||||
for (auto y : interference.lookup(x)) {
|
||||
int color = colors[y];
|
||||
if (color >= 0) {
|
||||
available[color] = false;
|
||||
}
|
||||
}
|
||||
auto it = std::find(available.begin(), available.end(), true);
|
||||
colors[x] = std::distance(available.begin(), it);
|
||||
}
|
||||
// Finalize allocation
|
||||
// color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15)
|
||||
// color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24)
|
||||
// color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42)
|
||||
// TODO(Keren): We are wasting memory here.
|
||||
// Nodes with color2 can actually start with 24.
|
||||
for (auto x : buffers) {
|
||||
size_t adj = 0;
|
||||
for (auto y : interference.lookup(x)) {
|
||||
adj = std::max(adj, bufferStart.lookup(y) + y->size);
|
||||
}
|
||||
x->offset = bufferStart.lookup(x) + colors.lookup(x) * adj;
|
||||
bufferStart[x] = x->offset;
|
||||
allocation->sharedMemorySize =
|
||||
std::max(allocation->sharedMemorySize, x->offset + x->size);
|
||||
}
|
||||
}
|
||||
|
||||
void dump() const {
|
||||
llvm::outs() << "DUMP: " << "\n";
|
||||
for (auto bufferIter : bufferRange) {
|
||||
|
||||
llvm::outs() << "ID= " << bufferIter.first->id << "\n";
|
||||
// llvm::outs() << " Kind= " << kind << "\n";
|
||||
llvm::outs() << " Size= " << bufferIter.first->size << "\n";
|
||||
llvm::outs() << " Offs= " << bufferIter.first->offset << "\n";
|
||||
llvm::outs() << " -> " << bufferIter.second.start() << "\n";
|
||||
llvm::outs() << " -> " << bufferIter.second.end() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Operation *operation;
|
||||
Allocation::FuncAllocMapT *funcAllocMap;
|
||||
Allocation *allocation;
|
||||
BufferRangeMapT bufferRange;
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
|
||||
void Allocation::run(FuncAllocMapT &funcAllocMap) {
|
||||
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
1041
python/triton/third_party/hip/lib/AnalysisROCM/AxisInfo.cpp
vendored
Normal file
1041
python/triton/third_party/hip/lib/AnalysisROCM/AxisInfo.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
18
python/triton/third_party/hip/lib/AnalysisROCM/CMakeLists.txt
vendored
Normal file
18
python/triton/third_party/hip/lib/AnalysisROCM/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
add_mlir_library(TritonAnalysisROCM
|
||||
AxisInfo.cpp
|
||||
Allocation.cpp
|
||||
Membar.cpp
|
||||
Alias.cpp
|
||||
Utility.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonTableGen
|
||||
TritonGPUROCMAttrDefsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAnalysis
|
||||
MLIRLLVMDialect
|
||||
TritonIR
|
||||
TritonGPUROCMIR
|
||||
TritonNvidiaGPUIR
|
||||
)
|
||||
213
python/triton/third_party/hip/lib/AnalysisROCM/Membar.cpp
vendored
Normal file
213
python/triton/third_party/hip/lib/AnalysisROCM/Membar.cpp
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
#include "triton/AnalysisROCM/Membar.h"
|
||||
#include "triton/AnalysisROCM/Alias.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
|
||||
#include "../lib/Conversion/TritonGPUROCMToLLVM/Utility.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include <deque>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
|
||||
FunctionOpInterface funcOp =
|
||||
dyn_cast<FunctionOpInterface>(allocation->getOperation());
|
||||
OpBuilder builder(funcOp.getContext());
|
||||
resolve(funcOp, &funcBlockInfoMap, &builder);
|
||||
}
|
||||
|
||||
void MembarAnalysis::resolve(FunctionOpInterface funcOp,
|
||||
FuncBlockInfoMapT *funcBlockInfoMap,
|
||||
OpBuilder *builder) {
|
||||
// Initialize the blockList
|
||||
DenseMap<Block *, BlockInfo> inputBlockInfoMap;
|
||||
DenseMap<Block *, BlockInfo> outputBlockInfoMap;
|
||||
std::deque<Block *> blockList;
|
||||
funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
|
||||
for (auto &op : block->getOperations()) {
|
||||
// Check if the operation belongs to scf dialect, if so, we need to
|
||||
// throw an error
|
||||
if (op.getDialect()->getNamespace() == "scf") {
|
||||
llvm::report_fatal_error(
|
||||
"scf dialect is not supported in membar. Please lower it "
|
||||
"to cf dialect first.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (block->isEntryBlock())
|
||||
blockList.emplace_back(block);
|
||||
});
|
||||
|
||||
// A fixed point algorithm
|
||||
while (!blockList.empty()) {
|
||||
auto *block = blockList.front();
|
||||
blockList.pop_front();
|
||||
// Make a copy of the inputblockInfo but not update
|
||||
auto inputBlockInfo = inputBlockInfoMap[block];
|
||||
SmallVector<Block *> successors;
|
||||
for (auto &op : block->getOperations()) {
|
||||
if (op.hasTrait<OpTrait::IsTerminator>()) {
|
||||
visitTerminator(&op, successors);
|
||||
} else {
|
||||
update(&op, &inputBlockInfo, funcBlockInfoMap, builder);
|
||||
}
|
||||
}
|
||||
// Get the reference because we want to update if it changed
|
||||
if (outputBlockInfoMap.count(block) &&
|
||||
inputBlockInfo == outputBlockInfoMap[block]) {
|
||||
// If we have seen the block before and the inputBlockInfo is the same as
|
||||
// the outputBlockInfo, we skip the successors
|
||||
continue;
|
||||
}
|
||||
// Update the current block
|
||||
outputBlockInfoMap[block].join(inputBlockInfo);
|
||||
// Update the successors
|
||||
for (auto *successor : successors) {
|
||||
inputBlockInfoMap[successor].join(outputBlockInfoMap[block]);
|
||||
blockList.emplace_back(successor);
|
||||
}
|
||||
}
|
||||
|
||||
// Update the final dangling buffers that haven't been synced
|
||||
auto &funcBlockInfo = (*funcBlockInfoMap)[funcOp];
|
||||
funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
|
||||
block->walk([&](triton::ReturnOp returnOp) {
|
||||
funcBlockInfo.join(outputBlockInfoMap[block]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void MembarAnalysis::visitTerminator(Operation *op,
|
||||
SmallVector<Block *> &successors) {
|
||||
if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
|
||||
Block *parentBlock = branchInterface->getBlock();
|
||||
successors.append(std::begin(parentBlock->getSuccessors()),
|
||||
std::end(parentBlock->getSuccessors()));
|
||||
return;
|
||||
}
|
||||
// Otherwise, it could be a return op
|
||||
if (isa<triton::ReduceReturnOp>(op) || isa<triton::ScanReturnOp>(op) ||
|
||||
isa<triton::ReturnOp>(op)) {
|
||||
return;
|
||||
}
|
||||
llvm_unreachable("Unknown terminator encountered in membar analysis");
|
||||
}
|
||||
|
||||
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
FuncBlockInfoMapT *funcBlockInfoMap,
|
||||
OpBuilder *builder) {
|
||||
if (isa<triton::gpu_rocm::ExtractSliceOp>(op) ||
|
||||
isa<triton::gpu_rocm::AllocTensorOp>(op) || isa<triton::TransOp>(op)) {
|
||||
// alloc is an allocation op without memory write.
|
||||
// FIXME(Keren): extract_slice is always alias for now
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(Keren): Don't expose LLVM Dialect ops here
|
||||
if (isa<gpu::BarrierOp>(op) ||
|
||||
(isa<LLVM::InlineAsmOp>(op) &&
|
||||
(dyn_cast<LLVM::InlineAsmOp>(op).getAsmString().find("bar.sync") !=
|
||||
std::string::npos))) {
|
||||
// If the current op is a barrier, we sync previous reads and writes
|
||||
blockInfo->sync();
|
||||
return;
|
||||
}
|
||||
|
||||
if (isa<triton::gpu_rocm::AsyncWaitOp, triton::gpu_rocm::AsyncBulkWaitOp>(op) &&
|
||||
!isa<gpu::BarrierOp>(op->getNextNode()) &&
|
||||
!(isa<LLVM::InlineAsmOp>(op->getNextNode()) &&
|
||||
(dyn_cast<LLVM::InlineAsmOp>(op->getNextNode())
|
||||
.getAsmString()
|
||||
.find("bar.sync") != std::string::npos))) {
|
||||
// If the current op is an async wait and the next op is not a barrier we
|
||||
// insert a barrier op and sync
|
||||
blockInfo->sync();
|
||||
OpBuilder::InsertionGuard g(*builder);
|
||||
builder->setInsertionPointAfter(op);
|
||||
if (auto optionalAgentId = getWSAgentId(op)) {
|
||||
int agentId = *optionalAgentId, roleId = 0;
|
||||
if (auto optionalRoleId = getWSRoleId(op))
|
||||
roleId = *optionalRoleId;
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
barSync(*builder, op, barId, 128);
|
||||
} else {
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
}
|
||||
blockInfo->sync();
|
||||
return;
|
||||
}
|
||||
|
||||
BlockInfo curBlockInfo;
|
||||
if (isa<triton::CallOp>(op)) {
|
||||
// Inter-function dependencies
|
||||
auto callOpInterface = dyn_cast<CallOpInterface>(op);
|
||||
if (auto callee =
|
||||
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable())) {
|
||||
curBlockInfo = funcBlockInfoMap->lookup(callee);
|
||||
}
|
||||
} else {
|
||||
// Intra-function dependencies
|
||||
for (Value value : op->getOperands()) {
|
||||
for (auto bufferId : allocation->getBufferIds(value)) {
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
if (isa<triton::gpu_rocm::InsertSliceAsyncOp>(op) ||
|
||||
isa<tensor::InsertSliceOp>(op)) {
|
||||
// FIXME(Keren): insert_slice and insert_slice_async are always
|
||||
// alias for now
|
||||
curBlockInfo.syncWriteIntervals.insert(
|
||||
allocation->getAllocatedInterval(bufferId));
|
||||
} else {
|
||||
// ConvertLayoutOp: shared memory -> registers
|
||||
curBlockInfo.syncReadIntervals.insert(
|
||||
allocation->getAllocatedInterval(bufferId));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Value value : op->getResults()) {
|
||||
// ConvertLayoutOp: registers -> shared memory
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curBlockInfo.syncWriteIntervals.insert(
|
||||
allocation->getAllocatedInterval(bufferId));
|
||||
}
|
||||
}
|
||||
// Scratch buffer is considered as both shared memory write & read
|
||||
auto bufferId = allocation->getBufferId(op);
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curBlockInfo.syncWriteIntervals.insert(
|
||||
allocation->getAllocatedInterval(bufferId));
|
||||
curBlockInfo.syncReadIntervals.insert(
|
||||
allocation->getAllocatedInterval(bufferId));
|
||||
}
|
||||
}
|
||||
|
||||
if (blockInfo->isIntersected(curBlockInfo)) {
|
||||
OpBuilder::InsertionGuard g(*builder);
|
||||
builder->setInsertionPoint(op);
|
||||
// TODO(Keren): Don't expose LLVM Dialect ops here
|
||||
// TODO[shuhaoj]: Change hard code style of numThreads. Hide async_agent
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (auto optionalAgentId = getWSAgentId(op)) {
|
||||
int agentId = *optionalAgentId, roleId = 0;
|
||||
if (auto optionalRoleId = getWSRoleId(op))
|
||||
roleId = *optionalRoleId;
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
barSync(*builder, op, barId, 128);
|
||||
} else {
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
}
|
||||
blockInfo->sync();
|
||||
}
|
||||
// Update the region info, even if barrier is inserted, we have to maintain
|
||||
// the current op's read/write buffers.
|
||||
blockInfo->join(curBlockInfo);
|
||||
}
|
||||
} // namespace mlir
|
||||
780
python/triton/third_party/hip/lib/AnalysisROCM/Utility.cpp
vendored
Normal file
780
python/triton/third_party/hip/lib/AnalysisROCM/Utility.cpp
vendored
Normal file
@@ -0,0 +1,780 @@
|
||||
#include "triton/AnalysisROCM/Utility.h"
|
||||
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include <deque>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
int getParentAxis(Attribute layout, int axis) {
|
||||
if (auto sliceEncoding = layout.dyn_cast<triton::gpu_rocm::SliceEncodingAttr>()) {
|
||||
axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
|
||||
return getParentAxis(sliceEncoding.getParent(), axis);
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getParentOrder(Attribute layout) {
|
||||
if (auto sliceEncoding = layout.dyn_cast<triton::gpu_rocm::SliceEncodingAttr>()) {
|
||||
return getParentOrder(sliceEncoding.getParent());
|
||||
}
|
||||
return triton::gpu_rocm::getOrder(layout);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ReduceOpHelper::isFastReduction() {
|
||||
// Disable fast reduction only for debugging purpose
|
||||
if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION"))
|
||||
return false;
|
||||
return getParentAxis(getSrcLayout(), axis) ==
|
||||
getParentOrder(getSrcLayout())[0];
|
||||
}
|
||||
|
||||
// Cases where distributed shared memory is not required in ConvertLayout:
|
||||
// (1) numCTAs == 1
|
||||
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
|
||||
// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented
|
||||
// in the future
|
||||
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
|
||||
unsigned numCTAs = triton::gpu_rocm::getNumCTAs(srcLayout);
|
||||
assert(numCTAs == triton::gpu_rocm::getNumCTAs(dstLayout) &&
|
||||
"Invalid layout conversion: the numbers of CTAs of src and dst "
|
||||
"layouts are different");
|
||||
|
||||
// Case (1): Never use dsmem when numCTAs == 1
|
||||
if (numCTAs == 1)
|
||||
return false;
|
||||
|
||||
// Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not
|
||||
// implemented yet
|
||||
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu_rocm::SliceEncodingAttr>()) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
auto CTAsPerCGA = triton::gpu_rocm::getCTAsPerCGA(sliceLayout.getParent());
|
||||
if (CTAsPerCGA[dim] != 1)
|
||||
assert(0 && "Layout conversion to be implemented");
|
||||
}
|
||||
|
||||
// Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported
|
||||
if (auto sliceLayout = dstLayout.dyn_cast<triton::gpu_rocm::SliceEncodingAttr>()) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
auto CTAsPerCGA = triton::gpu_rocm::getCTAsPerCGA(sliceLayout.getParent());
|
||||
if (CTAsPerCGA[dim] != 1)
|
||||
return true;
|
||||
}
|
||||
|
||||
// The above two branches make sure that it is legal to call getCTALayout of
|
||||
// srcLayout and dstLayout
|
||||
|
||||
// Case (2): Do not use dsmem when srcCTALayout == dstCTALayout
|
||||
auto srcCTALayout = triton::gpu_rocm::getCTALayout(srcLayout);
|
||||
auto dstCTALayout = triton::gpu_rocm::getCTALayout(dstLayout);
|
||||
if (srcCTALayout == dstCTALayout)
|
||||
return false;
|
||||
|
||||
// Dsmem access is required when srcCTALayout != dstCTALayout
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getInterWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned sizeIntraWarps = getIntraWarpSize();
|
||||
return std::min(srcReduceDimSize / sizeIntraWarps,
|
||||
triton::gpu_rocm::getWarpsPerCTA(getSrcLayout())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getIntraWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
return std::min(srcReduceDimSize,
|
||||
triton::gpu_rocm::getThreadsPerWarp(getSrcLayout())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData();
|
||||
return std::min(srcReduceDimSize / sizeIntraWarps,
|
||||
triton::gpu_rocm::getWarpsPerCTAWithUniqueData(
|
||||
getSrcLayout(), getSrcShape())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned elementPerThreads = triton::gpu_rocm::getUniqueContigPerThread(
|
||||
getSrcLayout(), getSrcShape())[axis];
|
||||
return std::min(srcReduceDimSize / elementPerThreads,
|
||||
triton::gpu_rocm::getThreadsPerWarpWithUniqueData(
|
||||
getSrcLayout(), getSrcShape())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
||||
auto srcLayout = getSrcLayout();
|
||||
auto srcShape = getSrcShape();
|
||||
return triton::gpu_rocm::getThreadsPerWarpWithUniqueData(srcLayout,
|
||||
srcShape)[axis] *
|
||||
triton::gpu_rocm::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
|
||||
auto smemShape = convertType<unsigned>(getSrcShape());
|
||||
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
|
||||
return smemShape;
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return isFastReduction() &&
|
||||
(triton::gpu_rocm::getWarpsPerCTA(argsLayout)[axis] == 1);
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
SmallVector<SmallVector<unsigned>> smemShapes(3);
|
||||
|
||||
auto argLayout = getSrcLayout();
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu_rocm::MmaEncodingAttr>();
|
||||
|
||||
// that case doesn't need inter-warp communication
|
||||
if (isWarpSynchronous())
|
||||
return {{0, 0}, {0, 0}};
|
||||
|
||||
/// shared memory block0
|
||||
smemShapes[0] = convertType<unsigned>(getSrcShape());
|
||||
smemShapes[0][axis] = getInterWarpSize();
|
||||
|
||||
/// FIXME(Qingyi): This size is actually larger than required.
|
||||
/// shared memory block1:
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu_rocm::TritonGPUROCMDialect::getNumWarps(mod);
|
||||
unsigned threadsPerWarp =
|
||||
triton::gpu_rocm::TritonGPUROCMDialect::getThreadsPerWarp(mod);
|
||||
smemShapes[1].push_back(numWarps * threadsPerWarp);
|
||||
|
||||
return smemShapes;
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getScratchSizeInBytes() {
|
||||
unsigned elems = 0;
|
||||
if (isFastReduction()) {
|
||||
auto smemShapes = getScratchConfigsFast();
|
||||
for (const auto &smemShape : smemShapes)
|
||||
elems = std::max(elems, product<unsigned>(smemShape));
|
||||
} else {
|
||||
auto smemShape = getScratchConfigBasic();
|
||||
elems = product<unsigned>(smemShape);
|
||||
}
|
||||
|
||||
unsigned bytesPerElem = 0;
|
||||
for (const auto &ty : srcElementTypes) {
|
||||
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
|
||||
}
|
||||
return bytesPerElem * elems;
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isSupportedLayout() {
|
||||
auto srcLayout = getSrcLayout();
|
||||
if (srcLayout.isa<triton::gpu_rocm::BlockedEncodingAttr>()) {
|
||||
return true;
|
||||
}
|
||||
if (auto mmaLayout = srcLayout.dyn_cast<triton::gpu_rocm::MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (auto mfmaLayout = srcLayout.dyn_cast<triton::gpu_rocm::MfmaEncodingAttr>()) {
|
||||
return true;
|
||||
}
|
||||
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu_rocm::SliceEncodingAttr>()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
|
||||
return getEncoding().getSizePerThread()[getAxis()];
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
|
||||
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
|
||||
sizePerThreads[getAxis()] = 1;
|
||||
return product<unsigned>(sizePerThreads);
|
||||
}
|
||||
|
||||
Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
|
||||
return triton::gpu_rocm::getThreadsPerWarp(getEncoding())[getAxis()];
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
|
||||
auto threadsPerWarp = triton::gpu_rocm::getThreadsPerWarp(getEncoding());
|
||||
threadsPerWarp[getAxis()] = 1;
|
||||
return product<unsigned>(threadsPerWarp);
|
||||
}
|
||||
|
||||
// Return the flat numbers of threads computing independent scan results.
|
||||
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
|
||||
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
|
||||
auto warpsPerCTA = triton::gpu_rocm::getWarpsPerCTA(getEncoding());
|
||||
warpsPerCTA[getAxis()] = 1;
|
||||
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
|
||||
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
|
||||
}
|
||||
unsigned ScanLoweringHelper::getAxisNumWarps() {
|
||||
auto warpsPerCTA = triton::gpu_rocm::getWarpsPerCTA(srcEncoding);
|
||||
return warpsPerCTA[getAxis()];
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisNumBlocks() {
|
||||
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto sizePerThreads = triton::gpu_rocm::getSizePerThread(srcEncoding);
|
||||
auto threadsPerWarp = triton::gpu_rocm::getThreadsPerWarp(srcEncoding);
|
||||
auto warpsPerCTA = triton::gpu_rocm::getWarpsPerCTA(srcEncoding);
|
||||
unsigned axis = getAxis();
|
||||
return ceil<unsigned>(
|
||||
type.getShape()[axis],
|
||||
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
|
||||
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto sizePerThreads = triton::gpu_rocm::getSizePerThread(srcEncoding);
|
||||
auto threadsPerWarp = triton::gpu_rocm::getThreadsPerWarp(srcEncoding);
|
||||
auto warpsPerCTA = triton::gpu_rocm::getWarpsPerCTA(srcEncoding);
|
||||
unsigned axis = getAxis();
|
||||
unsigned numBlocks = 1;
|
||||
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
|
||||
if (i == axis)
|
||||
continue;
|
||||
numBlocks *= ceil<unsigned>(
|
||||
type.getShape()[i],
|
||||
(sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
|
||||
}
|
||||
return numBlocks;
|
||||
}
|
||||
|
||||
bool ScanLoweringHelper::isSupported() {
|
||||
// TODO: Support the following cases:
|
||||
// 1. Scan on non-blocking encodings
|
||||
// 2. Scan with multiple operands
|
||||
if (!isa<triton::gpu_rocm::BlockedEncodingAttr>(srcEncoding))
|
||||
return false;
|
||||
if (scanOp.getNumOperands() != 1)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getScratchSizeInBytes() {
|
||||
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8;
|
||||
auto mod = scanOp->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu_rocm::TritonGPUROCMDialect::getNumWarps(mod);
|
||||
unsigned numNonAxisElementsPerWapr =
|
||||
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
|
||||
unsigned numElements = numWarps * numNonAxisElementsPerWapr *
|
||||
getAxisNumBlocks() * getNonAxisNumBlocks();
|
||||
return elementSizeInBytes * numElements;
|
||||
}
|
||||
|
||||
triton::gpu_rocm::BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
|
||||
return srcEncoding.cast<triton::gpu_rocm::BlockedEncodingAttr>();
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisElementStride() {
|
||||
auto order = triton::gpu_rocm::getOrder(srcEncoding);
|
||||
unsigned stride = 1;
|
||||
for (unsigned dim : order) {
|
||||
if (dim == getAxis())
|
||||
return stride;
|
||||
stride *= getContigPerThread(getEncoding())[dim];
|
||||
}
|
||||
llvm_unreachable("Axis not found in order");
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisThreadStride() {
|
||||
auto order = triton::gpu_rocm::getOrder(srcEncoding);
|
||||
unsigned stride = 1;
|
||||
for (unsigned dim : order) {
|
||||
if (dim == getAxis())
|
||||
return stride;
|
||||
stride *= getEncoding().getThreadsPerWarp()[dim];
|
||||
}
|
||||
llvm_unreachable("Axis not found in order");
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisBlockStride() {
|
||||
auto order = triton::gpu_rocm::getOrder(srcEncoding);
|
||||
unsigned stride = 1;
|
||||
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto sizePerThreads = triton::gpu_rocm::getSizePerThread(srcEncoding);
|
||||
auto threadsPerWarp = triton::gpu_rocm::getThreadsPerWarp(srcEncoding);
|
||||
auto warpsPerCTA = triton::gpu_rocm::getWarpsPerCTA(srcEncoding);
|
||||
for (unsigned dim : order) {
|
||||
if (dim == getAxis())
|
||||
return stride;
|
||||
stride *= ceil<unsigned int>(type.getShape()[dim], sizePerThreads[dim] *
|
||||
threadsPerWarp[dim] *
|
||||
warpsPerCTA[dim]);
|
||||
}
|
||||
llvm_unreachable("Axis not found in order");
|
||||
}
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op) {
|
||||
// TODO(Keren): This function can be replaced by adding
|
||||
// MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
|
||||
// query the memory effects of the op.
|
||||
auto *dialect = op->getDialect();
|
||||
return dialect &&
|
||||
(dialect->getTypeID() ==
|
||||
mlir::TypeID::get<triton::gpu_rocm::TritonGPUROCMDialect>() ||
|
||||
dialect->getTypeID() ==
|
||||
mlir::TypeID::get<triton::nvidia_gpu::TritonNvidiaGPUDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<arith::ArithDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
|
||||
}
|
||||
|
||||
bool maybeAliasOp(Operation *op) {
|
||||
return isa<triton::gpu_rocm::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
|
||||
isa<triton::gpu_rocm::InsertSliceAsyncOp>(op) ||
|
||||
isa<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op) ||
|
||||
isa<triton::nvidia_gpu::StoreAsyncOp>(op) ||
|
||||
isa<tensor::InsertSliceOp>(op);
|
||||
}
|
||||
|
||||
bool supportMMA(triton::DotOp op, int version) {
|
||||
// Refer to mma section for the data type supported by Volta and Hopper
|
||||
// Tensor Core in
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
if (version == 3) {
|
||||
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
|
||||
return false;
|
||||
auto retType = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto retShapePerCTA = triton::gpu_rocm::getShapePerCTA(retType);
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu_rocm::TritonGPUROCMDialect::getNumWarps(mod);
|
||||
if (!(numWarps % 4 == 0 && retShapePerCTA[0] % 64 == 0 &&
|
||||
retShapePerCTA[1] % 8 == 0 &&
|
||||
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
|
||||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
|
||||
aElemTy.isF32()))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (aElemTy.isF32() && bElemTy.isF32()) {
|
||||
return (op.getAllowTF32() && version == 2) || version == 3;
|
||||
}
|
||||
return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
|
||||
}
|
||||
|
||||
#if 1
|
||||
static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) {
|
||||
// these limitations are dtype dependent, in future we may relax them
|
||||
const int granularityMN = nonKDim;
|
||||
const int granularityK = nonKDim == 32 ? 8 : 16;
|
||||
if (m % granularityMN != 0 || n % granularityMN != 0)
|
||||
return false;
|
||||
if (k % granularityK != 0)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
|
||||
auto aTy = op.getA().getType().cast<RankedTensorType>();
|
||||
auto bTy = op.getB().getType().cast<RankedTensorType>();
|
||||
|
||||
auto aElemTy = aTy.getElementType();
|
||||
auto bElemTy = bTy.getElementType();
|
||||
|
||||
if (aElemTy != bElemTy)
|
||||
return false;
|
||||
|
||||
auto aShape = aTy.getShape();
|
||||
auto bShape = bTy.getShape();
|
||||
|
||||
assert(aShape[1] == bShape[0]);
|
||||
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim))
|
||||
return false;
|
||||
|
||||
return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
|
||||
aElemTy.isInteger(8);
|
||||
}
|
||||
#endif
|
||||
|
||||
bool supportMMA(Value value, int version) {
|
||||
// Tell whether a DotOp support MMA by the operand type(either $a or $b).
|
||||
// We cannot get both the operand types(in TypeConverter), here we assume the
|
||||
// types of both the operands are identical here.
|
||||
assert((version == 1 || version == 2 || version == 3) &&
|
||||
"Unexpected MMA layout version found");
|
||||
|
||||
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
|
||||
// FP8 is not natively supported on all mma versions but it can always be
|
||||
// promoted to fp16 therefore we can always support it.
|
||||
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
|
||||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
|
||||
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
|
||||
(elemTy.isF32() && version >= 2) ||
|
||||
(elemTy.isInteger(8) && version >= 2);
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto mmaLayout = srcLayout.cast<triton::gpu_rocm::MmaEncodingAttr>();
|
||||
auto dotOperandLayout = dstLayout.cast<triton::gpu_rocm::DotOperandEncodingAttr>();
|
||||
return mmaLayout.getVersionMajor() == 2 &&
|
||||
mmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mmaLayout &&
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
#if 1
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto mfmaLayout = srcLayout.cast<triton::gpu_rocm::MfmaEncodingAttr>();
|
||||
auto dotOperandLayout = dstLayout.cast<triton::gpu_rocm::DotOperandEncodingAttr>();
|
||||
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
|
||||
// improved. In addition, we can enable this shortcut for regular MFMA
|
||||
// layout when opIdx == 1.
|
||||
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getKWidth() == 4 &&
|
||||
dotOperandLayout.getParent() == mfmaLayout &&
|
||||
mfmaLayout.getNonKDim() == 32 && mfmaLayout.getIsTransposed() &&
|
||||
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
|
||||
}
|
||||
#endif
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto src = srcTy.getEncoding().cast<triton::gpu_rocm::MmaEncodingAttr>();
|
||||
auto dst = dstTy.getEncoding().cast<triton::gpu_rocm::MmaEncodingAttr>();
|
||||
auto srcElemsPerThread = triton::gpu_rocm::getTotalElemsPerThread(srcTy);
|
||||
auto dstElemsPerThread = triton::gpu_rocm::getTotalElemsPerThread(dstTy);
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 &&
|
||||
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
|
||||
srcElemsPerThread == dstElemsPerThread;
|
||||
}
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
|
||||
return tensorTy.getNumElements() == 1;
|
||||
// TODO: Handle other cases.
|
||||
// For example, when ptr is a tensor of single value.
|
||||
// It means that ptr is a resultant of broadcast or generated through
|
||||
// a chain of broadcast and other operations.
|
||||
// Rematerialize it without considering contiguous memory access pattern is
|
||||
// fine.
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// A data structure similar to SetVector but maintains
|
||||
/// a deque instead of a vector to allow for efficient
|
||||
/// push_back and pop_front operations.
|
||||
/// Using SetVector doesn't suffice our needs because
|
||||
/// it only pushes and pops from the back.
|
||||
/// For example, if we have a queue like this:
|
||||
/// 0->4 1->2->3
|
||||
/// ^--------
|
||||
/// where 3 depends on 4, once we pop 3, we found
|
||||
/// 4 is not ready, so we check 2 and push 3 back
|
||||
/// to the queue.
|
||||
struct DFSSubgraphState {
|
||||
DFSSubgraphState() : set(), deque() {}
|
||||
DenseSet<Operation *> set;
|
||||
std::deque<Operation *> deque;
|
||||
|
||||
bool push_back(Operation *op) {
|
||||
if (set.insert(op).second) {
|
||||
deque.push_back(op);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Operation *pop_front() {
|
||||
Operation *op = deque.front();
|
||||
deque.pop_front();
|
||||
set.erase(op);
|
||||
return op;
|
||||
}
|
||||
|
||||
bool empty() { return deque.empty(); }
|
||||
};
|
||||
|
||||
/// DFS post-order implementation that maintains a global count to work across
|
||||
/// multiple invocations, to help implement topological sort on multi-root DAGs.
|
||||
/// We traverse all operations but only record the ones that appear in
|
||||
/// `toSort` for the final result.
|
||||
struct DFSState {
|
||||
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
|
||||
const SetVector<Operation *> &toSort;
|
||||
SmallVector<Operation *, 16> topologicalCounts;
|
||||
DenseSet<Operation *> seen;
|
||||
|
||||
/// We mark each op as ready if all its operands and parents ops are seen. If
|
||||
/// an op is ready, we add it to the queue. Otherwise, we keep adding its
|
||||
/// operands to the ancestors set.
|
||||
/// We always want an op to be scheduled after all its parents to handle
|
||||
/// correctly cases with scf operations.
|
||||
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
|
||||
SmallVector<Operation *, 4> &readyQueue) {
|
||||
bool ready = true;
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto def = operand.getDefiningOp();
|
||||
if (def && !seen.count(def)) {
|
||||
subGraph.push_back(def);
|
||||
ready = false;
|
||||
}
|
||||
}
|
||||
Operation *parent = op->getParentOp();
|
||||
while (parent) {
|
||||
if (!seen.count(parent)) {
|
||||
subGraph.push_back(parent);
|
||||
ready = false;
|
||||
}
|
||||
parent = parent->getParentOp();
|
||||
}
|
||||
if (ready)
|
||||
readyQueue.push_back(op);
|
||||
}
|
||||
};
|
||||
|
||||
void dfsPostorder(Operation *root, DFSState *state) {
|
||||
DFSSubgraphState subGraph;
|
||||
subGraph.push_back(root);
|
||||
SmallVector<Operation *> ops;
|
||||
while (!subGraph.empty()) {
|
||||
// Nodes in the ready queue are ready to be processed.
|
||||
// Meaning that either their operands are all seen or they have null
|
||||
// operands.
|
||||
SmallVector<Operation *, 4> readyQueue;
|
||||
auto *current = subGraph.pop_front();
|
||||
state->addToReadyQueue(current, subGraph, readyQueue);
|
||||
while (!readyQueue.empty()) {
|
||||
Operation *current = readyQueue.pop_back_val();
|
||||
if (!state->seen.insert(current).second)
|
||||
continue;
|
||||
ops.push_back(current);
|
||||
for (Value result : current->getResults()) {
|
||||
for (Operation *op : result.getUsers())
|
||||
state->addToReadyQueue(op, subGraph, readyQueue);
|
||||
}
|
||||
for (Region ®ion : current->getRegions()) {
|
||||
for (Operation &op : region.getOps())
|
||||
state->addToReadyQueue(&op, subGraph, readyQueue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Operation *op : llvm::reverse(ops)) {
|
||||
if (state->toSort.count(op) > 0)
|
||||
state->topologicalCounts.push_back(op);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SetVector<Operation *>
|
||||
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
|
||||
if (toSort.empty()) {
|
||||
return toSort;
|
||||
}
|
||||
|
||||
// Run from each root with global count and `seen` set.
|
||||
DFSState state(toSort);
|
||||
for (auto *s : toSort) {
|
||||
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
|
||||
dfsPostorder(s, &state);
|
||||
}
|
||||
|
||||
// Reorder and return.
|
||||
SetVector<Operation *> res;
|
||||
for (auto it = state.topologicalCounts.rbegin(),
|
||||
eit = state.topologicalCounts.rend();
|
||||
it != eit; ++it) {
|
||||
res.insert(*it);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
SetVector<Operation *> multiRootGetSlice(Operation *op,
|
||||
TransitiveFilter backwardFilter,
|
||||
TransitiveFilter forwardFilter) {
|
||||
SetVector<Operation *> slice;
|
||||
slice.insert(op);
|
||||
|
||||
unsigned currentIndex = 0;
|
||||
SetVector<Operation *> backwardSlice;
|
||||
SetVector<Operation *> forwardSlice;
|
||||
while (currentIndex != slice.size()) {
|
||||
auto *currentOp = (slice)[currentIndex];
|
||||
// Compute and insert the backwardSlice starting from currentOp.
|
||||
backwardSlice.clear();
|
||||
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
|
||||
slice.insert(backwardSlice.begin(), backwardSlice.end());
|
||||
|
||||
// Compute and insert the forwardSlice starting from currentOp.
|
||||
forwardSlice.clear();
|
||||
getForwardSlice(currentOp, &forwardSlice, forwardFilter);
|
||||
slice.insert(forwardSlice.begin(), forwardSlice.end());
|
||||
++currentIndex;
|
||||
}
|
||||
return multiRootTopologicalSort(slice);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
|
||||
// interacts with constant propagation, but SparseConstantPropagation
|
||||
// doesn't seem to be sufficient.
|
||||
class ConstantAnalysis : public DataFlowAnalysis {
|
||||
public:
|
||||
using DataFlowAnalysis::DataFlowAnalysis;
|
||||
|
||||
LogicalResult initialize(Operation *top) override {
|
||||
WalkResult result = top->walk([&](Operation *op) {
|
||||
if (failed(visit(op)))
|
||||
return WalkResult::interrupt();
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return success(!result.wasInterrupted());
|
||||
}
|
||||
|
||||
LogicalResult visit(ProgramPoint point) override {
|
||||
Operation *op = point.get<Operation *>();
|
||||
Attribute value;
|
||||
if (matchPattern(op, m_Constant(&value))) {
|
||||
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
|
||||
op->getResult(0));
|
||||
propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
|
||||
value, op->getDialect())));
|
||||
return success();
|
||||
}
|
||||
// Dead code analysis requires every operands has initialized ConstantValue
|
||||
// state before it is visited.
|
||||
// https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322
|
||||
// That's why we need to set all operands to unknown constants.
|
||||
setAllToUnknownConstants(op->getResults());
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
for (Block &block : region.getBlocks())
|
||||
setAllToUnknownConstants(block.getArguments());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
/// Set all given values as not constants.
|
||||
void setAllToUnknownConstants(ValueRange values) {
|
||||
dataflow::ConstantValue unknownConstant(nullptr, nullptr);
|
||||
for (Value value : values) {
|
||||
auto *constant =
|
||||
getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
|
||||
propagateIfChanged(constant, constant->join(unknownConstant));
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
|
||||
auto solver = std::make_unique<DataFlowSolver>();
|
||||
solver->load<dataflow::DeadCodeAnalysis>();
|
||||
solver->load<ConstantAnalysis>();
|
||||
return solver;
|
||||
}
|
||||
|
||||
static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
|
||||
|
||||
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
|
||||
return makeTensorPtrOp;
|
||||
}
|
||||
|
||||
if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
|
||||
return getMakeTensorPtrOp(advanceOp.getPtr());
|
||||
}
|
||||
|
||||
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
|
||||
auto idx = v.cast<OpResult>().getResultNumber();
|
||||
llvm::SmallVector<scf::YieldOp> yieldOps;
|
||||
op->walk([&](Operation *op) {
|
||||
if (auto yieldOp = dyn_cast<scf::YieldOp>(op))
|
||||
yieldOps.push_back(yieldOp);
|
||||
});
|
||||
|
||||
// benzh@ if multi yields, all yields operand should come from same arg.
|
||||
Value newValue = yieldOps[0].getOperands()[idx];
|
||||
return getMakeTensorPtrOp(newValue);
|
||||
}
|
||||
|
||||
llvm_unreachable("Unable to getMakeTensorPtr()");
|
||||
}
|
||||
|
||||
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) {
|
||||
using BranchOps = llvm::SetVector<std::pair<Operation *, int>>;
|
||||
llvm::DenseMap<Block *, BranchOps> blockToCFOps;
|
||||
auto moduleOp =
|
||||
v.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
|
||||
moduleOp.walk([&](Operation *op) {
|
||||
if (auto br = dyn_cast<cf::BranchOp>(op)) {
|
||||
Block *block = br.getDest();
|
||||
blockToCFOps[block].insert({op, -1});
|
||||
}
|
||||
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
|
||||
Block *blockT = condBr.getTrueDest();
|
||||
Block *blockF = condBr.getFalseDest();
|
||||
blockToCFOps[blockT].insert({condBr, 1});
|
||||
blockToCFOps[blockF].insert({condBr, 0});
|
||||
}
|
||||
});
|
||||
|
||||
if (Operation *definingOp = v.getDefiningOp()) {
|
||||
return getMakeTensorPtrOpImpl(definingOp, v);
|
||||
} else if (BlockArgument arg = v.cast<BlockArgument>()) {
|
||||
unsigned argNum = arg.getArgNumber();
|
||||
Operation *argOwner = arg.getOwner()->getParentOp();
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
|
||||
return getMakeTensorPtrOp(
|
||||
forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));
|
||||
} else if (auto funcOp = dyn_cast<mlir::triton::FuncOp>(argOwner)) {
|
||||
Block *block = arg.getOwner();
|
||||
Operation *op;
|
||||
int tOrF;
|
||||
std::tie(op, tOrF) = blockToCFOps[block][0];
|
||||
if (auto br = dyn_cast<cf::BranchOp>(op)) {
|
||||
return getMakeTensorPtrOp(br.getDestOperands()[argNum]);
|
||||
}
|
||||
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
|
||||
if (tOrF) {
|
||||
return getMakeTensorPtrOp(condBr.getTrueDestOperands()[argNum]);
|
||||
} else {
|
||||
return getMakeTensorPtrOp(condBr.getFalseDestOperands()[argNum]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return getMakeTensorPtrOp(argOwner->getOperand(argNum));
|
||||
}
|
||||
}
|
||||
|
||||
llvm_unreachable("Unable to getMakeTensorPtr()");
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
5
python/triton/third_party/hip/lib/CMakeLists.txt
vendored
Normal file
5
python/triton/third_party/hip/lib/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# add_subdirectory(codegen)
|
||||
add_subdirectory(AnalysisROCM)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Target)
|
||||
2
python/triton/third_party/hip/lib/Conversion/CMakeLists.txt
vendored
Normal file
2
python/triton/third_party/hip/lib/Conversion/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(TritonToTritonGPUROCM)
|
||||
add_subdirectory(TritonGPUROCMToLLVM)
|
||||
217
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/BarrierOpToLLVM.cpp
vendored
Normal file
217
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/BarrierOpToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,217 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "BarrierOpToLLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// -- MBarrier related Ops lowering, to be moved to a seperate file ---------
|
||||
// --------------------------------------------------------------------------
|
||||
struct AllocMBarrierOpConversion : public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::AllocMBarrierOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::AllocMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::AllocMBarrierOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
|
||||
auto resultTy = op.getType();
|
||||
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
|
||||
Type elemPtrTy;
|
||||
if (resultTensorTy) {
|
||||
auto llvmElemTy =
|
||||
getTypeConverter()->convertType(resultTensorTy.getElementType());
|
||||
elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
} else {
|
||||
elemPtrTy = getTypeConverter()->convertType(resultTy);
|
||||
}
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto threadId = getThreadId(rewriter, loc);
|
||||
auto pred = icmp_eq(threadId, i32_val(0));
|
||||
int numMBarriers = 1;
|
||||
if (resultTensorTy) {
|
||||
assert(resultTensorTy.getRank() == 1 &&
|
||||
"unexpected rank for AllocMBarrierOp");
|
||||
numMBarriers = resultTensorTy.getShape()[0];
|
||||
}
|
||||
for (int i = 0; i < numMBarriers; ++i) {
|
||||
Value smem = smemBase;
|
||||
if (i > 0) {
|
||||
smem = gep(elemPtrTy, smem, i32_val(i));
|
||||
}
|
||||
rewriter.create<triton::nvgpu::MBarrierInitOp>(loc, smem, pred,
|
||||
op.getCount());
|
||||
}
|
||||
if (resultTensorTy) {
|
||||
auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(),
|
||||
{0}, loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
} else {
|
||||
rewriter.replaceOp(op, smemBase);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MBarrierArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::MBarrierArriveOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::MBarrierArriveOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto mbarrier = adaptor.getMbarrier();
|
||||
bool trackAsyncOp = op.getTrackAsyncOp();
|
||||
triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal;
|
||||
uint32_t txCount = op.getTxCount();
|
||||
auto remoteCtaId = adaptor.getRemoteCtaId();
|
||||
if (trackAsyncOp) {
|
||||
type = triton::nvgpu::MBarriveType::cp_async;
|
||||
} else if (remoteCtaId) {
|
||||
assert(txCount == 0 &&
|
||||
"remote arrive of transaction mbarrier is not implemented yet");
|
||||
type = triton::nvgpu::MBarriveType::remote;
|
||||
} else if (txCount > 0) {
|
||||
type = triton::nvgpu::MBarriveType::expect_tx;
|
||||
}
|
||||
Value pred = adaptor.getPred();
|
||||
if (pred == nullptr) {
|
||||
pred = int_val(/*width*/ 1, 1);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierArriveOp>(
|
||||
op, mbarrier, pred, remoteCtaId, type, txCount);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MBarrierWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::MBarrierWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::MBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::MBarrierWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierWaitOp>(
|
||||
op, adaptor.getMbarrier(), adaptor.getPhase());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractMBarrierOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::ExtractMBarrierOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::ExtractMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::ExtractMBarrierOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto elemTy =
|
||||
op.getTensor().getType().cast<RankedTensorType>().getElementType();
|
||||
auto tensorStruct = adaptor.getTensor();
|
||||
auto index = adaptor.getIndex();
|
||||
auto ptrTy =
|
||||
LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3);
|
||||
auto basePtr =
|
||||
extract_val(ptrTy, tensorStruct, rewriter.getDenseI64ArrayAttr(0));
|
||||
Value result = gep(ptrTy, basePtr, index);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct NamedBarrierArriveOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierArriveOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierArriveOp>::
|
||||
ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierArriveOp>(
|
||||
op, adaptor.getBar(), adaptor.getNumThreads());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct NamedBarrierWaitOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierWaitOp>(
|
||||
op, adaptor.getBar(), adaptor.getNumThreads());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FenceAsyncSharedOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::FenceAsyncSharedOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::FenceAsyncSharedOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::FenceAsyncSharedOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::FenceAsyncSharedOp>(
|
||||
op, adaptor.getBCluster());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateBarrierOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit) {
|
||||
patterns.add<AllocMBarrierOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<MBarrierArriveOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<MBarrierWaitOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<ExtractMBarrierOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<NamedBarrierArriveOpConversion>(typeConverter, allocation,
|
||||
benefit);
|
||||
patterns.add<NamedBarrierWaitOpConversion>(typeConverter, allocation,
|
||||
benefit);
|
||||
patterns.add<FenceAsyncSharedOpConversion>(typeConverter, allocation,
|
||||
benefit);
|
||||
}
|
||||
37
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/BarrierOpToLLVM.h
vendored
Normal file
37
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/BarrierOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_BARRIER_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_BARRIER_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateBarrierOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
69
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/CMakeLists.txt
vendored
Normal file
69
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
add_library(rocm_libraries SHARED IMPORTED )
|
||||
set_target_properties(rocm_libraries PROPERTIES IMPORTED_LOCATION ${ROCM_LIBRARIES})
|
||||
|
||||
add_mlir_conversion_library(TritonGPUROCMToLLVM
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
|
||||
ConvertLayoutOpToLLVM.cpp
|
||||
DotOpToLLVM/FMA.cpp
|
||||
DotOpToLLVM/MMAv1.cpp
|
||||
DotOpToLLVM/MMAv2.cpp
|
||||
DotOpToLLVM/WGMMA.cpp
|
||||
DotOpToLLVM.cpp
|
||||
ElementwiseOpToLLVM.cpp
|
||||
LoadStoreOpToLLVM.cpp
|
||||
BarrierOpToLLVM.cpp
|
||||
TritonGPUToLLVM.cpp
|
||||
GCNAsmFormat.cpp
|
||||
PTXAsmFormat.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
|
||||
ConvertLayoutOpToLLVM.cpp
|
||||
DotOpToLLVM/FMA.cpp
|
||||
DotOpToLLVM/MMAv1.cpp
|
||||
DotOpToLLVM/MMAv2.cpp
|
||||
DotOpToLLVM/MFMA.cpp
|
||||
DotOpToLLVM.cpp
|
||||
ElementwiseOpToLLVM.cpp
|
||||
LoadStoreOpToLLVM.cpp
|
||||
TritonGPUToLLVM.cpp
|
||||
TritonGPUToLLVMPass.cpp
|
||||
GCNAsmFormat.cpp
|
||||
PTXAsmFormat.cpp
|
||||
ReduceOpToLLVM.cpp
|
||||
ScanOpToLLVM.cpp
|
||||
TypeConverter.cpp
|
||||
Utility.cpp
|
||||
ViewOpToLLVM.cpp
|
||||
TensorPtrOpsToLLVM.cpp
|
||||
ClusterOpsToLLVM.cpp
|
||||
RegReallocOpToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUROCMToLLVM
|
||||
${PROJECT_BINARY_DIR}/include/triton/Conversion/TritonGPUROCMToLLVM
|
||||
|
||||
DEPENDS
|
||||
TritonGPUROCMConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRGPUOps
|
||||
MLIRGPUToNVVMTransforms
|
||||
MLIRGPUToROCDLTransforms
|
||||
MLIRGPUTransforms
|
||||
TritonAnalysisROCM
|
||||
TritonIR
|
||||
TritonGPUROCMIR
|
||||
TritonGPUROCMTransforms
|
||||
TritonNvidiaGPUTransforms
|
||||
NVGPUIR
|
||||
rocm_libraries
|
||||
)
|
||||
62
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ClusterOpsToLLVM.cpp
vendored
Normal file
62
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ClusterOpsToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "ClusterOpsToLLVM.h"
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
struct ClusterArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::ClusterArriveOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::ClusterArriveOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterArriveOp>(
|
||||
op, op.getRelaxed());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ClusterWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::ClusterWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::ClusterWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterWaitOp>(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateClusterOpsToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit) {
|
||||
patterns.add<ClusterArriveOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ClusterWaitOpConversion>(typeConverter, benefit);
|
||||
return;
|
||||
}
|
||||
37
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ClusterOpsToLLVM.h
vendored
Normal file
37
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ClusterOpsToLLVM.h
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_CLUSTER_OPS_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_CLUSTER_OPS_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateClusterOpsToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
1107
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ConvertLayoutOpToLLVM.cpp
vendored
Normal file
1107
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ConvertLayoutOpToLLVM.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
18
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ConvertLayoutOpToLLVM.h
vendored
Normal file
18
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ConvertLayoutOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_CONVERT_LAYOUT_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_CONVERT_LAYOUT_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
|
||||
void populateConvertLayoutOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,232 @@
|
||||
#include "../ConvertLayoutOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using ValueTable = std::map<std::pair<int, int>, Value>;
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getContigPerThread;
|
||||
using ::mlir::triton::gpu_rocm::getOrder;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::getSizePerThread;
|
||||
using ::mlir::triton::gpu_rocm::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu_rocm::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
|
||||
SmallVector<Value>
|
||||
getThreadIds(Value threadId, ArrayRef<unsigned int> shapePerCTATile,
|
||||
ArrayRef<unsigned int> sizePerThread, ArrayRef<unsigned int> order,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
int dim = order.size();
|
||||
SmallVector<Value> threadIds(dim);
|
||||
for (unsigned k = 0; k < dim - 1; k++) {
|
||||
Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]);
|
||||
Value rem = urem(threadId, dimK);
|
||||
threadId = udiv(threadId, dimK);
|
||||
threadIds[order[k]] = rem;
|
||||
}
|
||||
Value dimK = i32_val(shapePerCTATile[order[dim - 1]]);
|
||||
threadIds[order[dim - 1]] = urem(threadId, dimK);
|
||||
return threadIds;
|
||||
}
|
||||
|
||||
// Get shapePerCTATile for M or N axis.
|
||||
int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) {
|
||||
auto order = layout.getOrder();
|
||||
auto shapePerCTATile = getShapePerCTATile(layout);
|
||||
|
||||
int mShapePerCTATile =
|
||||
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
int nShapePerCTATile =
|
||||
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
return isM ? mShapePerCTATile : nShapePerCTATile;
|
||||
}
|
||||
|
||||
// Get sizePerThread for M or N axis.
|
||||
int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) {
|
||||
auto order = layout.getOrder();
|
||||
auto sizePerThread = getSizePerThread(layout);
|
||||
|
||||
int mSizePerThread =
|
||||
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
int nSizePerThread =
|
||||
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
return isM ? mSizePerThread : nSizePerThread;
|
||||
}
|
||||
|
||||
Value getStructFromValueTable(ArrayRef<Value> vals,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
Type elemTy) {
|
||||
SmallVector<Type> elemTypes(vals.size(), elemTy);
|
||||
SmallVector<Value> elems;
|
||||
elems.reserve(vals.size());
|
||||
for (auto &val : vals) {
|
||||
elems.push_back(val);
|
||||
}
|
||||
MLIRContext *ctx = elemTy.getContext();
|
||||
Type structTy = struct_ty(elemTypes);
|
||||
return typeConverter->packLLElements(loc, elems, rewriter, structTy);
|
||||
}
|
||||
|
||||
ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA,
|
||||
int sizePerThread,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
Type type) {
|
||||
ValueTable res;
|
||||
auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type);
|
||||
int index = 0;
|
||||
for (unsigned k = 0; k < K; ++k) {
|
||||
for (unsigned m = 0; m < n0; m += shapePerCTA)
|
||||
for (unsigned mm = 0; mm < sizePerThread; ++mm) {
|
||||
res[{m + mm, k}] = elems[index++];
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
|
||||
Location loc, TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto aTensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto aShapePerCTA = getShapePerCTA(aTensorTy);
|
||||
|
||||
auto aOrder = aLayout.getOrder();
|
||||
auto order = dLayout.getOrder();
|
||||
|
||||
bool isARow = aOrder[0] == 1;
|
||||
|
||||
auto aSmem = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
|
||||
Value strideAM = aSmem.strides[0];
|
||||
Value strideAK = aSmem.strides[1];
|
||||
Value strideA0 = isARow ? strideAK : strideAM;
|
||||
Value strideA1 = isARow ? strideAM : strideAK;
|
||||
int aNumPtr = 8;
|
||||
int K = aShapePerCTA[1];
|
||||
int M = aShapePerCTA[0];
|
||||
|
||||
auto shapePerCTATile = getShapePerCTATile(dLayout);
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
|
||||
Value _0 = i32_val(0);
|
||||
|
||||
Value mContig = i32_val(sizePerThread[order[1]]);
|
||||
|
||||
// threadId in blocked layout
|
||||
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
|
||||
rewriter, loc);
|
||||
Value threadIdM = threadIds[0];
|
||||
|
||||
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
|
||||
Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
|
||||
SmallVector<Value> aOff(aNumPtr);
|
||||
for (int i = 0; i < aNumPtr; ++i) {
|
||||
aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1));
|
||||
}
|
||||
auto elemTy = typeConverter->convertType(
|
||||
A.getType().cast<RankedTensorType>().getElementType());
|
||||
|
||||
Type ptrTy = ptr_ty(elemTy, 3);
|
||||
SmallVector<Value> aPtrs(aNumPtr);
|
||||
for (int i = 0; i < aNumPtr; ++i)
|
||||
aPtrs[i] = gep(ptrTy, aSmem.base, aOff[i]);
|
||||
|
||||
SmallVector<Value> vas;
|
||||
|
||||
int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/);
|
||||
int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/);
|
||||
|
||||
for (unsigned k = 0; k < K; ++k)
|
||||
for (unsigned m = 0; m < M; m += mShapePerCTATile)
|
||||
for (unsigned mm = 0; mm < mSizePerThread; ++mm) {
|
||||
Value offset =
|
||||
add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK));
|
||||
Value pa = gep(ptrTy, aPtrs[0], offset);
|
||||
Value va = load(pa);
|
||||
vas.emplace_back(va);
|
||||
}
|
||||
|
||||
return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy);
|
||||
}
|
||||
|
||||
Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
|
||||
Location loc, TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto bShapePerCTA = getShapePerCTA(bTensorTy);
|
||||
|
||||
auto bOrder = bLayout.getOrder();
|
||||
auto order = dLayout.getOrder();
|
||||
|
||||
bool isBRow = bOrder[0] == 1;
|
||||
|
||||
auto bSmem = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
|
||||
Value strideBN = bSmem.strides[1];
|
||||
Value strideBK = bSmem.strides[0];
|
||||
Value strideB0 = isBRow ? strideBN : strideBK;
|
||||
Value strideB1 = isBRow ? strideBK : strideBN;
|
||||
int bNumPtr = 8;
|
||||
int K = bShapePerCTA[0];
|
||||
int N = bShapePerCTA[1];
|
||||
|
||||
auto shapePerCTATile = getShapePerCTATile(dLayout);
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
|
||||
Value _0 = i32_val(0);
|
||||
|
||||
Value nContig = i32_val(sizePerThread[order[0]]);
|
||||
|
||||
// threadId in blocked layout
|
||||
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
|
||||
rewriter, loc);
|
||||
Value threadIdN = threadIds[1];
|
||||
|
||||
Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
|
||||
Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
|
||||
SmallVector<Value> bOff(bNumPtr);
|
||||
for (int i = 0; i < bNumPtr; ++i) {
|
||||
bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1));
|
||||
}
|
||||
auto elemTy = typeConverter->convertType(
|
||||
B.getType().cast<RankedTensorType>().getElementType());
|
||||
|
||||
Type ptrTy = ptr_ty(elemTy, 3);
|
||||
SmallVector<Value> bPtrs(bNumPtr);
|
||||
for (int i = 0; i < bNumPtr; ++i)
|
||||
bPtrs[i] = gep(ptrTy, bSmem.base, bOff[i]);
|
||||
|
||||
SmallVector<Value> vbs;
|
||||
|
||||
int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/);
|
||||
int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/);
|
||||
|
||||
for (unsigned k = 0; k < K; ++k)
|
||||
for (unsigned n = 0; n < N; n += nShapePerCTATile)
|
||||
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
||||
Value offset =
|
||||
add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK));
|
||||
Value pb = gep(ptrTy, bPtrs[0], offset);
|
||||
Value vb = load(pb);
|
||||
vbs.emplace_back(vb);
|
||||
}
|
||||
|
||||
return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy);
|
||||
}
|
||||
|
||||
namespace SharedToDotOperandFMA {
|
||||
Value convertLayout(int opIdx, Value val, Value llVal,
|
||||
BlockedEncodingAttr dLayout, Value thread, Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
if (opIdx == 0)
|
||||
return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter);
|
||||
else
|
||||
return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter);
|
||||
}
|
||||
} // namespace SharedToDotOperandFMA
|
||||
@@ -0,0 +1,725 @@
|
||||
#if 1
|
||||
|
||||
#include "../ConvertLayoutOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getOrder;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
|
||||
namespace {
|
||||
|
||||
Type getShemPtrTy(Type elemTy) {
|
||||
if (elemTy.isBF16()) {
|
||||
auto ctx = elemTy.getContext();
|
||||
return ptr_ty(type::i16Ty(ctx), 3);
|
||||
}
|
||||
return ptr_ty(elemTy, 3);
|
||||
}
|
||||
|
||||
// Get a waveId for M axis.
|
||||
Value getWaveM(ConversionPatternRewriter &rewriter, Location loc, Value wave,
|
||||
const ArrayRef<unsigned int> &wpt, int elemPerInstr, int M) {
|
||||
return urem(urem(wave, i32_val(wpt[0])), i32_val(M / elemPerInstr));
|
||||
}
|
||||
// Get a waveId for N axis.
|
||||
Value getWaveN(ConversionPatternRewriter &rewriter, Location loc, Value wave,
|
||||
const ArrayRef<unsigned int> &wpt, int elemPerInstr, int N) {
|
||||
Value waveMN = udiv(wave, i32_val(wpt[0]));
|
||||
return urem(urem(waveMN, i32_val(wpt[1])), i32_val(N / elemPerInstr));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace SharedToDotOperandMFMA {
|
||||
|
||||
/**
|
||||
* @brief swizzling tensor element indexes according pattern encoded in
|
||||
* SharedEncodingAttr
|
||||
*
|
||||
* @param rewriter
|
||||
* @param loc
|
||||
* @param row row of target tensor element related to the start of smemObj
|
||||
* @param col col of target tensor element related to the start of smemObj
|
||||
* @param smemObj shared memory object, contains info about tensor in LDS
|
||||
* @param attr layout attribute, contains swizzling info
|
||||
* @return swizzled row, col indexes in tensor notation
|
||||
*/
|
||||
std::pair<mlir::Value, mlir::Value>
|
||||
swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
|
||||
Value col, SharedMemoryObject smemObj, SharedEncodingAttr attr) {
|
||||
(void)smemObj; // unused in current pattern
|
||||
bool transposed = (attr.getOrder()[0] != 1);
|
||||
if (transposed) {
|
||||
// tensor is column-wise, so swapping col and row in computations
|
||||
std::swap(row, col);
|
||||
}
|
||||
auto vec = i32_val(attr.getVec());
|
||||
auto perPhase = i32_val(attr.getPerPhase());
|
||||
auto maxPhase = i32_val(attr.getMaxPhase());
|
||||
|
||||
// Original algorithm taken from getSwizzledSharedPtrs function
|
||||
// (TritonGPUToLLVMBase.h): Basic algorithm for row-major tensor is following:
|
||||
//
|
||||
// phase = (row // perPhase) % maxPhase
|
||||
// colOffSwizzled = ((col // vec) ^ phase) * vec
|
||||
// colOffOrdered = col % vec
|
||||
// colOff = colOffSwizzled + colOffOrdered
|
||||
auto phase = urem(udiv(row, perPhase), maxPhase);
|
||||
auto colOffSwizzled = mul(xor_(udiv(col, vec), phase), vec);
|
||||
auto colOffOrdered = urem(col, vec);
|
||||
auto colOff = add(colOffSwizzled, colOffOrdered);
|
||||
|
||||
if (transposed)
|
||||
return {colOff, row};
|
||||
else
|
||||
return {row, colOff};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function maps particular load of mfma dot operand to element
|
||||
* indexes(row, col)
|
||||
*
|
||||
* Whole tensor is broken into "blocks" of waves along "non-K" axis.
|
||||
* One block could be processed by multiple waves.
|
||||
* One wave works on a piece of tensor size elemsPerInstr[0] x K.
|
||||
* Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
|
||||
* elemsPerInstr[1].
|
||||
*
|
||||
* Total offset of element is a sum of following values:
|
||||
* 1. Offset of wave block in tensor
|
||||
* 2. Offset of wave inside one wave block
|
||||
* 3. Offset of tile in one wave
|
||||
* 4. Offset of one lane data in a tile
|
||||
* 5. Offset of particular element of tensor processed by one lane
|
||||
*
|
||||
* This function computes these offsets for axies independently
|
||||
*
|
||||
* @param rewriter
|
||||
* @param loc
|
||||
* @param elemsPerInstr operand tile shape consumed by one MFMA instruction
|
||||
* @param waveId id component of 2d wave grid along nono-K axis
|
||||
* @param laneId lane id in warp [0..63]
|
||||
* @param warpsPerGroup number of warps in one block
|
||||
* @param numOfElems number of elements accessed by thread per repetition
|
||||
* @param reps number of instructions repretition to fully cover dot operand
|
||||
* @param smemStrides strides in LDS tensor
|
||||
* @param loadVecSize number of elements loaded by one operation
|
||||
* @param iNonKDim non-K dimension of dot operand
|
||||
* @return vector (i-th element corresponds to i-th load instruction) of
|
||||
* 2-element vectors(tensor row and col).
|
||||
*/
|
||||
llvm::SmallVector<llvm::SmallVector<Value>>
|
||||
computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
|
||||
int loadVecSize, unsigned iNonKDim) {
|
||||
auto numM = reps[0];
|
||||
auto numK = reps[1];
|
||||
const int loadsPerThread = numOfElems / loadVecSize;
|
||||
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numM * numK *
|
||||
loadsPerThread);
|
||||
|
||||
Value _0 = i32_val(0);
|
||||
Value _32 = i32_val(32);
|
||||
Value nonKDim = i32_val(iNonKDim);
|
||||
|
||||
for (int block = 0; block < numM; ++block) {
|
||||
Value blockVOffset = i32_val(block * elemsPerInstr[0] * warpsPerGroup);
|
||||
Value blockHOffset = _0;
|
||||
Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0]));
|
||||
Value waveHOffset = _0;
|
||||
for (int tile = 0; tile < numK; ++tile) {
|
||||
Value tileVOffset = _0;
|
||||
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);
|
||||
|
||||
Value laneVOffset = urem(laneId, nonKDim);
|
||||
Value laneHOffset;
|
||||
if (iNonKDim == 32)
|
||||
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
|
||||
else
|
||||
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
|
||||
|
||||
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
Value elemVOffset = _0;
|
||||
Value elemHOffset = i32_val(loadId * loadVecSize);
|
||||
|
||||
Value sliceVOffset = add(
|
||||
add(add(add(blockVOffset, waveVOffset), tileVOffset), laneVOffset),
|
||||
elemVOffset);
|
||||
Value sliceHOffset = add(
|
||||
add(add(add(blockHOffset, waveHOffset), tileHOffset), laneHOffset),
|
||||
elemHOffset);
|
||||
|
||||
Value row = add(sliceVOffset, smemOffsets[0]);
|
||||
Value col = add(sliceHOffset, smemOffsets[1]);
|
||||
|
||||
mapping[numK * loadsPerThread * block + loadsPerThread * tile +
|
||||
loadId] = {row, col};
|
||||
}
|
||||
}
|
||||
}
|
||||
return mapping;
|
||||
}
|
||||
|
||||
bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; }
|
||||
|
||||
Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value row, Value col, SharedMemoryObject smemObj,
|
||||
SharedEncodingAttr srcLayout) {
|
||||
auto [swizzledRow, swizzledCol] =
|
||||
swizzleIndexes(rewriter, loc, row, col, smemObj, srcLayout);
|
||||
auto &strides = smemObj.strides;
|
||||
Value rowOffset = mul(swizzledRow, strides[0]);
|
||||
Value colOffset = mul(swizzledCol, strides[1]);
|
||||
return add(rowOffset, colOffset);
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value>
|
||||
computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
|
||||
SharedEncodingAttr srcLayout, unsigned nonKDim) {
|
||||
SmallVector<Value> strides{smemObj.strides[0], smemObj.strides[1]};
|
||||
SmallVector<Value> offsets{smemObj.offsets[0], smemObj.offsets[1]};
|
||||
|
||||
int vectorSize = 1;
|
||||
if (srcLayout.getOrder()[0] == 1) {
|
||||
if (isSwizzled(srcLayout))
|
||||
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
|
||||
else
|
||||
vectorSize = numOfElems;
|
||||
}
|
||||
|
||||
auto mapping = computeTensorElemMapping(rewriter, loc, elemsPerInstr, waveId,
|
||||
laneId, warpsPerGroup, numOfElems,
|
||||
reps, offsets, vectorSize, nonKDim);
|
||||
llvm::SmallVector<Value> aOffsets(mapping.size());
|
||||
for (int i = 0; i < mapping.size(); ++i) {
|
||||
Value row = mapping[i][0];
|
||||
Value col = mapping[i][1];
|
||||
aOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
|
||||
}
|
||||
return aOffsets;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value>
|
||||
computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
|
||||
SharedEncodingAttr srcLayout, unsigned nonKDim) {
|
||||
// transpose reps and offsets, because operand B has layout equal to
|
||||
// transposed operand A layout
|
||||
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
|
||||
SmallVector<int64_t> tReps{reps[1], reps[0]};
|
||||
SmallVector<Value> toffsets{smemObj.offsets[1], smemObj.offsets[0]};
|
||||
|
||||
int vectorSize = 1;
|
||||
if (srcLayout.getOrder()[0] == 0) {
|
||||
if (isSwizzled(srcLayout))
|
||||
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
|
||||
else
|
||||
vectorSize = numOfElems;
|
||||
}
|
||||
|
||||
auto mapping = computeTensorElemMapping(rewriter, loc, tElemsPerInstr, waveId,
|
||||
laneId, warpsPerGroup, numOfElems,
|
||||
tReps, toffsets, vectorSize, nonKDim);
|
||||
llvm::SmallVector<Value> bOffsets(mapping.size());
|
||||
for (int i = 0; i < mapping.size(); ++i) {
|
||||
// swap row and col, because operand B layout is a transposed operand A
|
||||
// layout
|
||||
Value row = mapping[i][1];
|
||||
Value col = mapping[i][0];
|
||||
bOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
|
||||
}
|
||||
return bOffsets;
|
||||
}
|
||||
|
||||
Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const SharedMemoryObject &smemObj) {
|
||||
Value base = smemObj.base;
|
||||
Type type = base.getType();
|
||||
for (int i = 0; i < smemObj.strides.size(); ++i) {
|
||||
Value offset = sub(i32_val(0), mul(smemObj.offsets[i], smemObj.strides[i]));
|
||||
base = gep(type, base, offset);
|
||||
}
|
||||
return base;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief try find if value is an integer constant
|
||||
*
|
||||
* Trace def-use chain and return integer in case we can proof it is constant.
|
||||
* Current implementation can trace chains of insertValue->extractValue
|
||||
* operations.
|
||||
*
|
||||
* @param val Value for that we want to get constant
|
||||
* @return std::optional on found integer value or empty std::optional
|
||||
*/
|
||||
std::optional<int> findConstValue(Value val) {
|
||||
while (val && !val.getDefiningOp<LLVM::ConstantOp>()) {
|
||||
LLVM::ExtractValueOp extractValOp =
|
||||
val.getDefiningOp<LLVM::ExtractValueOp>();
|
||||
if (!extractValOp)
|
||||
return std::optional<int>();
|
||||
auto extractPosArr = extractValOp.getPosition();
|
||||
if (extractPosArr.size() > 1)
|
||||
return std::optional<int>();
|
||||
int extractPos = extractPosArr[0];
|
||||
|
||||
int insertPos = -1;
|
||||
LLVM::InsertValueOp insertValOp;
|
||||
Value container = extractValOp.getOperand();
|
||||
do {
|
||||
insertValOp = container.getDefiningOp<LLVM::InsertValueOp>();
|
||||
if (!insertValOp)
|
||||
return std::optional<int>();
|
||||
auto insertPosArr = insertValOp.getPosition();
|
||||
if (insertPosArr.size() > 1)
|
||||
return std::optional<int>();
|
||||
insertPos = insertPosArr[0];
|
||||
container = insertValOp.getContainer();
|
||||
} while (insertPos != extractPos);
|
||||
val = insertValOp.getValue();
|
||||
}
|
||||
if (!val)
|
||||
return std::optional<int>();
|
||||
auto cOp = val.getDefiningOp<LLVM::ConstantOp>();
|
||||
assert(cOp);
|
||||
auto valAttr = cOp.getValueAttr();
|
||||
auto intAttr = dyn_cast<mlir::IntegerAttr>(valAttr);
|
||||
assert(intAttr);
|
||||
return intAttr.getInt();
|
||||
}
|
||||
|
||||
bool fastPathAvailable(const SharedMemoryObject &smemObj,
|
||||
const SharedEncodingAttr &srcEncoding,
|
||||
const MfmaEncodingAttr &dstEncoding) {
|
||||
if (dstEncoding.getNonKDim() != 32)
|
||||
return false;
|
||||
if (srcEncoding.getMaxPhase() > 1)
|
||||
return false;
|
||||
auto stride0 = findConstValue(smemObj.strides[0]);
|
||||
auto stride1 = findConstValue(smemObj.strides[1]);
|
||||
auto offset0 = findConstValue(smemObj.offsets[0]);
|
||||
auto offset1 = findConstValue(smemObj.offsets[1]);
|
||||
bool allValuesDefined = stride0.has_value() && stride1.has_value() &&
|
||||
offset0.has_value() && offset1.has_value();
|
||||
if (!allValuesDefined)
|
||||
return false;
|
||||
if (offset0.value() != 0 || offset1.value() != 0)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Computes offsets for operand A or transposed operand B
|
||||
// @param rewriter
|
||||
// @param loc
|
||||
// @param elemsPerInstr operand tile shape consumed by one MFMA instruction
|
||||
// @param waveM wave id for the "non K" axis
|
||||
// @param laneId lane id in warp [0..63]
|
||||
// @param warpsPerGroup number of warps in one block
|
||||
// @param numOfElems number of elements accessed by thread per repetition
|
||||
// @param reps number of instructions repretition to fully cover dot operand
|
||||
// @param cSwizzleOffset
|
||||
llvm::SmallVector<Value>
|
||||
fastPathComputeOffsetsTy1(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
|
||||
const int loadVecSize = numOfElems;
|
||||
const int loadsPerThread = 1; // 1 is just in case if we decide to use different loadVecSize
|
||||
auto numM = reps[0];
|
||||
auto numK = reps[1];
|
||||
SmallVector<Value> offsets(numM * numK * loadsPerThread);
|
||||
int lineSize = elemsPerInstr[1] * numK;
|
||||
int blockSize = elemsPerInstr[0] * warpsPerGroup * lineSize;
|
||||
Value _0 = i32_val(0);
|
||||
Value _32 = i32_val(32);
|
||||
Value waveHalf = udiv(laneId, _32);
|
||||
|
||||
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[0] * lineSize));
|
||||
Value colOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
|
||||
|
||||
for (int block = 0; block < numM; ++block) {
|
||||
Value blockOffset = i32_val(block * blockSize);
|
||||
for (int tile = 0; tile < numK; ++tile) {
|
||||
Value tileOffset = i32_val(tile * elemsPerInstr[1]);
|
||||
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
Value rowOffset =
|
||||
add(mul(urem(laneId, _32), i32_val(lineSize)), i32_val(loadId * loadVecSize));
|
||||
Value elemOffset = add(rowOffset, colOffset);
|
||||
Value offset =
|
||||
add(add(add(waveOffset, blockOffset), tileOffset), elemOffset);
|
||||
offsets[numK * loadsPerThread * block + loadsPerThread * tile + loadId] = offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
// Computes offsets for operand B or transposed operand A
|
||||
// @param rewriter
|
||||
// @param loc
|
||||
// @param elemsPerInstr operand tile shape consumed by one MFMA instruction
|
||||
// @param waveId wave id for the "non K" axis
|
||||
// @param laneId lane id in warp [0..63]
|
||||
// @param warpsPerGroup number of warps per horizontal axis
|
||||
// @param numOfElems number of elements accessed by threads per repetition
|
||||
// @param reps number of instructions repretition to fully cover dot operand
|
||||
// @param cSwizzleOffset
|
||||
llvm::SmallVector<Value>
|
||||
fastPathComputeOffsetsTy2(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
|
||||
auto numK = reps[0];
|
||||
auto numN = reps[1];
|
||||
SmallVector<Value> offsets(numK * numN * numOfElems);
|
||||
|
||||
int lineSize = warpsPerGroup * elemsPerInstr[1] * numN;
|
||||
Value _0 = i32_val(0);
|
||||
Value _32 = i32_val(32);
|
||||
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[1]));
|
||||
Value colOffset = urem(laneId, _32);
|
||||
|
||||
for (int block = 0; block < numN; ++block) {
|
||||
Value blockOffset = i32_val(block * elemsPerInstr[1] * warpsPerGroup);
|
||||
for (int tile = 0; tile < numK; ++tile) {
|
||||
Value tileOffset = i32_val(tile * elemsPerInstr[0] * lineSize);
|
||||
for (int elem = 0; elem < numOfElems; ++elem) {
|
||||
Value halfOffset =
|
||||
select(icmp_uge(laneId, _32), i32_val(numOfElems * lineSize), _0);
|
||||
Value rowOffset = add(i32_val(elem * lineSize), halfOffset);
|
||||
Value elemOffset = add(rowOffset, colOffset);
|
||||
Value offset =
|
||||
add(add(add(waveOffset, blockOffset), tileOffset), elemOffset);
|
||||
offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
bool isTransposed(::llvm::ArrayRef<unsigned> order) {
|
||||
assert(order.size() == 2 && (order[0] & ~1ul) == 0 &&
|
||||
order[0] + order[1] == 1);
|
||||
return order[0] == 0;
|
||||
}
|
||||
|
||||
Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
DotOperandEncodingAttr encoding,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
|
||||
const SharedMemoryObject &smemObj) {
|
||||
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
|
||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
|
||||
aTensorTy.getShape().end());
|
||||
auto sharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
auto aElemTy = aTensorTy.getElementType();
|
||||
auto aElemsPerInstr = encoding.getMFMAElemsPerInstr();
|
||||
auto mfmaInstrM = aElemsPerInstr[0];
|
||||
auto mfmaInstrK = aElemsPerInstr[1];
|
||||
|
||||
auto numReps = encoding.getMFMARep(shape, aElemTy);
|
||||
auto numRepM = numReps[0];
|
||||
auto numRepK = numReps[1];
|
||||
|
||||
unsigned iWaveSize = triton::gpu_rocm::getWarpSize(mfmaLayout);
|
||||
assert(iWaveSize == 64);
|
||||
Value waveSize = i32_val(iWaveSize);
|
||||
Value wave = udiv(thread, waveSize);
|
||||
Value lane = urem(thread, waveSize);
|
||||
|
||||
Value waveM =
|
||||
getWaveM(rewriter, loc, wave, warpsPerCTA, mfmaInstrM, shape[0]);
|
||||
int numOfElems = mfmaInstrM * mfmaInstrK / iWaveSize;
|
||||
assert(numOfElems >= 1);
|
||||
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
|
||||
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
|
||||
|
||||
SmallVector<Value> ha;
|
||||
|
||||
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
SmallVector<Value> offsets;
|
||||
if (isTransposed(order)) {
|
||||
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrM};
|
||||
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
||||
offsets = fastPathComputeOffsetsTy2(rewriter, loc, elemsPerInstr, waveM,
|
||||
lane, warpsPerGroupM, numOfElems,
|
||||
reps, cSwizzleOffset);
|
||||
} else {
|
||||
offsets = fastPathComputeOffsetsTy1(rewriter, loc, aElemsPerInstr, waveM,
|
||||
lane, warpsPerGroupM, numOfElems,
|
||||
numReps, cSwizzleOffset);
|
||||
}
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(aElemTy);
|
||||
Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
|
||||
|
||||
int loadsPerThread = offsets.size() / (numRepM * numRepK);
|
||||
const int elemsPerLoad = numOfElems / loadsPerThread;
|
||||
assert(numOfElems % loadsPerThread == 0);
|
||||
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
|
||||
Value loadOffset =
|
||||
offsets[m * loadsPerThread * numRepK + k * loadsPerThread + loadId];
|
||||
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
||||
getShemPtrTy(loadVecTy));
|
||||
Value vectorValue = load(loadAddress);
|
||||
if (numOfElems > 1) {
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(aElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (aElemTy == i8_ty)
|
||||
valVec = bitcast(valVec, i32_ty);
|
||||
ha.push_back(valVec);
|
||||
}
|
||||
}
|
||||
} else { // normal path
|
||||
SmallVector<Value> offsets = computeOffsetsAType(
|
||||
rewriter, loc, aElemsPerInstr, waveM, lane, warpsPerGroupM, numOfElems,
|
||||
numReps, smemObj, sharedLayout, nonKDim);
|
||||
|
||||
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
|
||||
Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(aElemTy);
|
||||
|
||||
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
|
||||
int elemsPerLoad = numOfElems / loadsPerThread;
|
||||
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
|
||||
Value loadOffset = offsets[m * loadsPerThread * numRepK +
|
||||
k * loadsPerThread + loadId];
|
||||
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
||||
getShemPtrTy(loadVecTy));
|
||||
Value vectorValue = load(loadAddress);
|
||||
if (numOfElems > 1) {
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(aElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (aElemTy == i8_ty)
|
||||
valVec = bitcast(valVec, i32_ty);
|
||||
ha.push_back(valVec);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MLIRContext *ctx = mfmaLayout.getContext();
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(ha.size(), ha[0].getType()));
|
||||
auto result = typeConverter->packLLElements(loc, ha, rewriter, structTy);
|
||||
return result;
|
||||
}
|
||||
|
||||
Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
DotOperandEncodingAttr encoding,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
|
||||
const SharedMemoryObject &smemObj) {
|
||||
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
|
||||
auto bTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
ArrayRef<int64_t> shape = bTensorTy.getShape();
|
||||
auto sharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
auto bElemTy = bTensorTy.getElementType();
|
||||
auto bElemsPerInstr = encoding.getMFMAElemsPerInstr();
|
||||
auto mfmaInstrK = bElemsPerInstr[0];
|
||||
auto mfmaInstrN = bElemsPerInstr[1];
|
||||
|
||||
auto numReps = encoding.getMFMARep(shape, bElemTy);
|
||||
auto numRepK = numReps[0];
|
||||
auto numRepN = numReps[1];
|
||||
|
||||
unsigned iWaveSize = triton::gpu_rocm::getWarpSize(mfmaLayout);
|
||||
assert(iWaveSize == 64);
|
||||
Value waveSize = i32_val(iWaveSize);
|
||||
Value wave = udiv(thread, waveSize);
|
||||
Value lane = urem(thread, waveSize);
|
||||
|
||||
Value waveN =
|
||||
getWaveN(rewriter, loc, wave, warpsPerCTA, mfmaInstrN, shape[1]);
|
||||
int numOfElems = mfmaInstrK * mfmaInstrN / iWaveSize;
|
||||
assert(numOfElems >= 1);
|
||||
|
||||
unsigned int maxNumWarps = shape[1] / mfmaInstrN;
|
||||
int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
|
||||
|
||||
SmallVector<Value> hb;
|
||||
|
||||
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
|
||||
llvm::SmallVector<Value> offsets;
|
||||
unsigned int maxNumWarps = shape[1] / mfmaInstrN;
|
||||
int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
|
||||
if (isTransposed(order)) {
|
||||
SmallVector<int64_t> elemsPerInstr{mfmaInstrN, mfmaInstrK};
|
||||
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
||||
offsets = fastPathComputeOffsetsTy1(rewriter, loc, elemsPerInstr, waveN,
|
||||
lane, warpsPerGroupN, numOfElems,
|
||||
reps, cSwizzleOffset);
|
||||
} else {
|
||||
offsets = fastPathComputeOffsetsTy2(rewriter, loc, bElemsPerInstr, waveN,
|
||||
lane, warpsPerGroupN, numOfElems,
|
||||
numReps, cSwizzleOffset);
|
||||
}
|
||||
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
|
||||
Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(bElemTy);
|
||||
|
||||
const int loadsPerThread = offsets.size() / (numRepN * numRepK);
|
||||
const int elemsPerLoad = numOfElems / loadsPerThread;
|
||||
assert(numOfElems % loadsPerThread == 0);
|
||||
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
|
||||
Value loadOffset =
|
||||
offsets[n * loadsPerThread * numRepK + k * loadsPerThread + loadId];
|
||||
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
||||
getShemPtrTy(loadVecTy));
|
||||
Value vectorValue = load(loadAddress);
|
||||
if (numOfElems > 1) {
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(bElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (bElemTy == i8_ty)
|
||||
valVec = bitcast(valVec, i32_ty);
|
||||
hb.push_back(valVec);
|
||||
}
|
||||
}
|
||||
} else { // normal path
|
||||
llvm::SmallVector<Value> offsets = computeOffsetsBType(
|
||||
rewriter, loc, bElemsPerInstr, waveN, lane, warpsPerGroupN, numOfElems,
|
||||
numReps, smemObj, sharedLayout, nonKDim);
|
||||
|
||||
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
|
||||
|
||||
Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(bElemTy);
|
||||
|
||||
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
|
||||
int elemsPerLoad = numOfElems / loadsPerThread;
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
|
||||
Value loadOffset = offsets[n * loadsPerThread * numRepK +
|
||||
k * loadsPerThread + loadId];
|
||||
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
||||
getShemPtrTy(loadVecTy));
|
||||
Value vectorValue = load(loadAddress);
|
||||
if (numOfElems > 1) {
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(bElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (bElemTy == i8_ty)
|
||||
valVec = bitcast(valVec, i32_ty);
|
||||
hb.push_back(valVec);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MLIRContext *ctx = mfmaLayout.getContext();
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(hb.size(), hb[0].getType()));
|
||||
auto result = typeConverter->packLLElements(loc, hb, rewriter, structTy);
|
||||
return result;
|
||||
}
|
||||
|
||||
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value tensor, DotOperandEncodingAttr encoding,
|
||||
const SharedMemoryObject &smemObj,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
|
||||
switch (opIdx) {
|
||||
case 0:
|
||||
// operand $a
|
||||
return loadA(rewriter, loc, thread, encoding, typeConverter, tensor,
|
||||
smemObj);
|
||||
case 1:
|
||||
// operand $b
|
||||
return loadB(rewriter, loc, thread, encoding, typeConverter, tensor,
|
||||
smemObj);
|
||||
default:
|
||||
assert(false && "unexpected operand idx");
|
||||
return Value();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace SharedToDotOperandMFMA
|
||||
|
||||
#endif // ifdef USE_ROCM
|
||||
@@ -0,0 +1,459 @@
|
||||
#include "../ConvertLayoutOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using CoordTy = SmallVector<Value>;
|
||||
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getContigPerThread;
|
||||
using ::mlir::triton::gpu_rocm::getOrder;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::getSizePerThread;
|
||||
using ::mlir::triton::gpu_rocm::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu_rocm::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
|
||||
// Compute the offset of the matrix to load.
|
||||
// Returns offsetAM, offsetAK, offsetBN, offsetBK.
|
||||
// NOTE, the information M(from $a) and N(from $b) couldn't be retrieved at
|
||||
// the same time in the usage in convert_layout[shared->dot_op], we leave
|
||||
// the noexist info to be 0 and only use the desired argument from the
|
||||
// composed result. In this way we want to retain the original code
|
||||
// structure in convert_mma884 method for easier debugging.
|
||||
static std::tuple<Value, Value, Value, Value>
|
||||
computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
|
||||
ArrayRef<int> spw, ArrayRef<int> rep,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type resultTy) {
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto wpt = resultTy.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<MmaEncodingAttr>()
|
||||
.getWarpsPerCTA();
|
||||
|
||||
Value _1 = i32_val(1);
|
||||
Value _3 = i32_val(3);
|
||||
Value _4 = i32_val(4);
|
||||
Value _16 = i32_val(16);
|
||||
Value _32 = i32_val(32);
|
||||
|
||||
Value lane = urem(threadId, _32);
|
||||
Value warp = udiv(threadId, _32);
|
||||
|
||||
// warp offset
|
||||
Value warp0 = urem(warp, i32_val(wpt[0]));
|
||||
Value warp12 = udiv(warp, i32_val(wpt[0]));
|
||||
Value warp1 = urem(warp12, i32_val(wpt[1]));
|
||||
Value warpMOff = mul(warp0, i32_val(spw[0]));
|
||||
Value warpNOff = mul(warp1, i32_val(spw[1]));
|
||||
// Quad offset
|
||||
Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0]));
|
||||
Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1]));
|
||||
// Pair offset
|
||||
Value pairMOff = udiv(urem(lane, _16), _4);
|
||||
pairMOff = urem(pairMOff, i32_val(fpw[0]));
|
||||
pairMOff = mul(pairMOff, _4);
|
||||
Value pairNOff = udiv(urem(lane, _16), _4);
|
||||
pairNOff = udiv(pairNOff, i32_val(fpw[0]));
|
||||
pairNOff = urem(pairNOff, i32_val(fpw[1]));
|
||||
pairNOff = mul(pairNOff, _4);
|
||||
// scale
|
||||
pairMOff = mul(pairMOff, i32_val(rep[0] / 2));
|
||||
quadMOff = mul(quadMOff, i32_val(rep[0] / 2));
|
||||
pairNOff = mul(pairNOff, i32_val(rep[1] / 2));
|
||||
quadNOff = mul(quadNOff, i32_val(rep[1] / 2));
|
||||
// Quad pair offset
|
||||
Value laneMOff = add(pairMOff, quadMOff);
|
||||
Value laneNOff = add(pairNOff, quadNOff);
|
||||
// A offset
|
||||
Value offsetAM = add(warpMOff, laneMOff);
|
||||
Value offsetAK = and_(lane, _3);
|
||||
// B offset
|
||||
Value offsetBN = add(warpNOff, laneNOff);
|
||||
Value offsetBK = and_(lane, _3);
|
||||
// i indices
|
||||
Value offsetCM = add(and_(lane, _1), offsetAM);
|
||||
if (isARow) {
|
||||
offsetAM = add(offsetAM, urem(threadId, _4));
|
||||
offsetAK = i32_val(0);
|
||||
}
|
||||
if (!isBRow) {
|
||||
offsetBN = add(offsetBN, urem(threadId, _4));
|
||||
offsetBK = i32_val(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK);
|
||||
}
|
||||
|
||||
static Value loadA(Value tensor, const SharedMemoryObject &smemObj,
|
||||
Value thread, Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Type resultTy) {
|
||||
static constexpr std::array<int, 3> fpw{{2, 2, 1}};
|
||||
auto wpt = resultTy.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<MmaEncodingAttr>()
|
||||
.getWarpsPerCTA();
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto shape = tensorTy.getShape();
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
|
||||
bool isARow = order[0] != 0;
|
||||
auto resultEncoding = resultTy.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
auto [offsetAM, offsetAK, _3, _4] = computeOffsets(
|
||||
thread, isARow, false, fpw, resultEncoding.getMMAv1ShapePerWarp(),
|
||||
resultEncoding.getMMAv1Rep(), rewriter, loc, resultTy);
|
||||
|
||||
int vecA = sharedLayout.getVec();
|
||||
|
||||
auto strides = smemObj.strides;
|
||||
Value strideAM = isARow ? strides[0] : i32_val(1);
|
||||
Value strideAK = isARow ? i32_val(1) : strides[1];
|
||||
Value strideA0 = isARow ? strideAK : strideAM;
|
||||
Value strideA1 = isARow ? strideAM : strideAK;
|
||||
|
||||
int strideRepM = wpt[0] * fpw[0] * 8;
|
||||
int strideRepK = 1;
|
||||
|
||||
// swizzling
|
||||
int perPhaseA = sharedLayout.getPerPhase();
|
||||
int maxPhaseA = sharedLayout.getMaxPhase();
|
||||
int stepA0 = isARow ? strideRepK : strideRepM;
|
||||
int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1);
|
||||
int NK = shape[1];
|
||||
|
||||
// pre-compute pointer lanes
|
||||
Value offA0 = isARow ? offsetAK : offsetAM;
|
||||
Value offA1 = isARow ? offsetAM : offsetAK;
|
||||
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
|
||||
offA0 = add(offA0, cSwizzleOffset);
|
||||
SmallVector<Value> offA(numPtrA);
|
||||
for (int i = 0; i < numPtrA; i++) {
|
||||
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
|
||||
offA0I = udiv(offA0I, i32_val(vecA));
|
||||
offA0I = xor_(offA0I, phaseA);
|
||||
offA0I = mul(offA0I, i32_val(vecA));
|
||||
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
|
||||
}
|
||||
|
||||
Type elemX2Ty = vec_ty(f16_ty, 2);
|
||||
Type elemPtrTy = ptr_ty(f16_ty, 3);
|
||||
if (tensorTy.getElementType().isBF16()) {
|
||||
elemX2Ty = vec_ty(i16_ty, 2);
|
||||
elemPtrTy = ptr_ty(i16_ty, 3);
|
||||
}
|
||||
|
||||
// prepare arguments
|
||||
SmallVector<Value> ptrA(numPtrA);
|
||||
|
||||
std::map<std::pair<int, int>, std::pair<Value, Value>> has;
|
||||
for (int i = 0; i < numPtrA; i++)
|
||||
ptrA[i] = gep(ptr_ty(f16_ty, 3), smemBase, offA[i]);
|
||||
|
||||
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
|
||||
vals[{m, k}] = {val0, val1};
|
||||
};
|
||||
auto loadA = [&](int m, int k) {
|
||||
int offidx = (isARow ? k / 4 : m) % numPtrA;
|
||||
Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]);
|
||||
|
||||
int stepAM = isARow ? m : m / numPtrA * numPtrA;
|
||||
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
|
||||
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
|
||||
mul(i32_val(stepAK), strideAK));
|
||||
Value pa = gep(elemPtrTy, thePtrA, offset);
|
||||
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
|
||||
Value ha = load(bitcast(pa, aPtrTy));
|
||||
// record lds that needs to be moved
|
||||
Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty);
|
||||
Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty);
|
||||
ld(has, m, k, ha00, ha01);
|
||||
|
||||
if (vecA > 4) {
|
||||
Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty);
|
||||
Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty);
|
||||
if (isARow)
|
||||
ld(has, m, k + 4, ha10, ha11);
|
||||
else
|
||||
ld(has, m + 1, k, ha10, ha11);
|
||||
}
|
||||
};
|
||||
|
||||
bool isARow_ = resultEncoding.getMMAv1IsRow();
|
||||
bool isAVec4 = resultEncoding.getMMAv1IsVec4();
|
||||
unsigned numM = resultEncoding.getMMAv1NumOuter(shape);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
if (!has.count({m, k}))
|
||||
loadA(m, k);
|
||||
|
||||
SmallVector<Value> elems;
|
||||
elems.reserve(has.size() * 2);
|
||||
for (auto item : has) { // has is a map, the key should be ordered.
|
||||
elems.push_back(bitcast(item.second.first, i32_ty));
|
||||
elems.push_back(bitcast(item.second.second, i32_ty));
|
||||
}
|
||||
|
||||
Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);
|
||||
return res;
|
||||
}
|
||||
|
||||
static Value loadB(Value tensor, const SharedMemoryObject &smemObj,
|
||||
Value thread, Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Type resultTy) {
|
||||
static constexpr std::array<int, 3> fpw{{2, 2, 1}};
|
||||
auto wpt = resultTy.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<MmaEncodingAttr>()
|
||||
.getWarpsPerCTA();
|
||||
// smem
|
||||
auto strides = smemObj.strides;
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
|
||||
auto shape = tensorTy.getShape();
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0; // is row-major in shared memory layout
|
||||
// isBRow_ indicates whether B is row-major in DotOperand layout
|
||||
auto resultEncoding = resultTy.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
|
||||
int vecB = sharedLayout.getVec();
|
||||
Value strideBN = isBRow ? i32_val(1) : strides[1];
|
||||
Value strideBK = isBRow ? strides[0] : i32_val(1);
|
||||
Value strideB0 = isBRow ? strideBN : strideBK;
|
||||
Value strideB1 = isBRow ? strideBK : strideBN;
|
||||
int strideRepN = wpt[1] * fpw[1] * 8;
|
||||
int strideRepK = 1;
|
||||
|
||||
auto [_3, _4, offsetBN, offsetBK] = computeOffsets(
|
||||
thread, false, isBRow, fpw, resultEncoding.getMMAv1ShapePerWarp(),
|
||||
resultEncoding.getMMAv1Rep(), rewriter, loc, resultTy);
|
||||
|
||||
// swizzling
|
||||
int perPhaseB = sharedLayout.getPerPhase();
|
||||
int maxPhaseB = sharedLayout.getMaxPhase();
|
||||
int stepB0 = isBRow ? strideRepN : strideRepK;
|
||||
int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
|
||||
int NK = shape[0];
|
||||
|
||||
Value offB0 = isBRow ? offsetBN : offsetBK;
|
||||
Value offB1 = isBRow ? offsetBK : offsetBN;
|
||||
Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
|
||||
offB0 = add(offB0, cSwizzleOffset);
|
||||
SmallVector<Value> offB(numPtrB);
|
||||
for (int i = 0; i < numPtrB; ++i) {
|
||||
Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4)));
|
||||
offB0I = udiv(offB0I, i32_val(vecB));
|
||||
offB0I = xor_(offB0I, phaseB);
|
||||
offB0I = mul(offB0I, i32_val(vecB));
|
||||
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
|
||||
}
|
||||
|
||||
Type elemPtrTy = ptr_ty(f16_ty, 3);
|
||||
Type elemX2Ty = vec_ty(f16_ty, 2);
|
||||
if (tensorTy.getElementType().isBF16()) {
|
||||
elemPtrTy = ptr_ty(i16_ty, 3);
|
||||
elemX2Ty = vec_ty(i16_ty, 2);
|
||||
}
|
||||
|
||||
SmallVector<Value> ptrB(numPtrB);
|
||||
ValueTable hbs;
|
||||
for (int i = 0; i < numPtrB; ++i)
|
||||
ptrB[i] = gep(ptr_ty(f16_ty, 3), smem, offB[i]);
|
||||
|
||||
auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) {
|
||||
vals[{m, k}] = {val0, val1};
|
||||
};
|
||||
|
||||
auto loadB = [&](int n, int K) {
|
||||
int offidx = (isBRow ? n : K / 4) % numPtrB;
|
||||
Value thePtrB = ptrB[offidx];
|
||||
|
||||
int stepBN = isBRow ? n / numPtrB * numPtrB : n;
|
||||
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
|
||||
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
|
||||
mul(i32_val(stepBK), strideBK));
|
||||
Value pb = gep(elemPtrTy, thePtrB, offset);
|
||||
|
||||
Value hb =
|
||||
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
|
||||
// record lds that needs to be moved
|
||||
Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty);
|
||||
Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty);
|
||||
ld(hbs, n, K, hb00, hb01);
|
||||
if (vecB > 4) {
|
||||
Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty);
|
||||
Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty);
|
||||
if (isBRow)
|
||||
ld(hbs, n + 1, K, hb10, hb11);
|
||||
else
|
||||
ld(hbs, n, K + 4, hb10, hb11);
|
||||
}
|
||||
};
|
||||
|
||||
bool isBRow_ = resultEncoding.getMMAv1IsRow();
|
||||
assert(isBRow == isBRow_ && "B need smem isRow");
|
||||
bool isBVec4 = resultEncoding.getMMAv1IsVec4();
|
||||
unsigned numN = resultEncoding.getMMAv1NumOuter(shape);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
if (!hbs.count({n, k}))
|
||||
loadB(n, k);
|
||||
}
|
||||
|
||||
SmallVector<Value> elems;
|
||||
for (auto &item : hbs) { // has is a map, the key should be ordered.
|
||||
elems.push_back(bitcast(item.second.first, i32_ty));
|
||||
elems.push_back(bitcast(item.second.second, i32_ty));
|
||||
}
|
||||
|
||||
Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);
|
||||
return res;
|
||||
}
|
||||
|
||||
namespace SharedToDotOperandMMAv1 {
|
||||
using CoordTy = SmallVector<Value>;
|
||||
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
|
||||
|
||||
SmallVector<CoordTy> getMNCoords(Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<unsigned int> wpt,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape, bool isARow,
|
||||
bool isBRow, bool isAVec4, bool isBVec4) {
|
||||
static constexpr std::array<int, 3> fpw{{2, 2, 1}};
|
||||
|
||||
auto *ctx = thread.getContext();
|
||||
Value _1 = i32_val(1);
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
Value _16 = i32_val(16);
|
||||
Value _32 = i32_val(32);
|
||||
Value _fpw0 = i32_val(fpw[0]);
|
||||
Value _fpw1 = i32_val(fpw[1]);
|
||||
|
||||
// A info
|
||||
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
|
||||
auto aRep = aEncoding.getMMAv1Rep();
|
||||
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
|
||||
// B info
|
||||
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
|
||||
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
|
||||
auto bRep = bEncoding.getMMAv1Rep();
|
||||
|
||||
SmallVector<int, 2> rep({aRep[0], bRep[1]});
|
||||
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
|
||||
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
|
||||
|
||||
Value lane = urem(thread, _32);
|
||||
Value warp = udiv(thread, _32);
|
||||
|
||||
Value warp0 = urem(warp, i32_val(wpt[0]));
|
||||
Value warp12 = udiv(warp, i32_val(wpt[0]));
|
||||
Value warp1 = urem(warp12, i32_val(wpt[1]));
|
||||
|
||||
// warp offset
|
||||
Value offWarpM = mul(warp0, i32_val(spw[0]));
|
||||
Value offWarpN = mul(warp1, i32_val(spw[1]));
|
||||
// quad offset
|
||||
Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
|
||||
Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
|
||||
// pair offset
|
||||
Value offPairM = udiv(urem(lane, _16), _4);
|
||||
offPairM = urem(offPairM, _fpw0);
|
||||
offPairM = mul(offPairM, _4);
|
||||
Value offPairN = udiv(urem(lane, _16), _4);
|
||||
offPairN = udiv(offPairN, _fpw0);
|
||||
offPairN = urem(offPairN, _fpw1);
|
||||
offPairN = mul(offPairN, _4);
|
||||
|
||||
// sclare
|
||||
offPairM = mul(offPairM, i32_val(rep[0] / 2));
|
||||
offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
|
||||
offPairN = mul(offPairN, i32_val(rep[1] / 2));
|
||||
offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
|
||||
|
||||
// quad pair offset
|
||||
Value offLaneM = add(offPairM, offQuadM);
|
||||
Value offLaneN = add(offPairN, offQuadN);
|
||||
// a, b offset
|
||||
Value offsetAM = add(offWarpM, offLaneM);
|
||||
Value offsetBN = add(offWarpN, offLaneN);
|
||||
// m indices
|
||||
Value offsetCM = add(and_(lane, _1), offsetAM);
|
||||
SmallVector<Value> idxM;
|
||||
for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
|
||||
for (unsigned mm = 0; mm < rep[0]; ++mm)
|
||||
idxM.push_back(add(offsetCM, i32_val(m + mm * 2)));
|
||||
|
||||
// n indices
|
||||
Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
|
||||
SmallVector<Value> idxN;
|
||||
for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
|
||||
for (int nn = 0; nn < rep[1]; ++nn) {
|
||||
idxN.push_back(add(
|
||||
offsetCN, i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1])));
|
||||
idxN.push_back(
|
||||
add(offsetCN,
|
||||
i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1)));
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<Value>> axes({idxM, idxN});
|
||||
|
||||
// product the axis M and axis N to get coords, ported from
|
||||
// generator::init_idx method from triton2.0
|
||||
|
||||
// TODO[Superjomn]: check the order.
|
||||
SmallVector<CoordTy> coords;
|
||||
for (Value x1 : axes[1]) { // N
|
||||
for (Value x0 : axes[0]) { // M
|
||||
SmallVector<Value, 2> idx(2);
|
||||
idx[0] = x0; // M
|
||||
idx[1] = x1; // N
|
||||
coords.push_back(std::move(idx));
|
||||
}
|
||||
}
|
||||
|
||||
return coords; // {M,N} in row-major
|
||||
}
|
||||
|
||||
Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj,
|
||||
Value thread, Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Type resultTy) {
|
||||
if (opIdx == 0)
|
||||
return loadA(tensor, smemObj, thread, loc, typeConverter, rewriter,
|
||||
resultTy);
|
||||
else {
|
||||
assert(opIdx == 1);
|
||||
return loadB(tensor, smemObj, thread, loc, typeConverter, rewriter,
|
||||
resultTy);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace SharedToDotOperandMMAv1
|
||||
@@ -0,0 +1,637 @@
|
||||
#include "../ConvertLayoutOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getContigPerThread;
|
||||
using ::mlir::triton::gpu_rocm::getOrder;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::getSizePerThread;
|
||||
using ::mlir::triton::gpu_rocm::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu_rocm::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
|
||||
// Data loader for mma.16816 instruction.
|
||||
class MMA16816SmemLoader {
|
||||
public:
|
||||
MMA16816SmemLoader(int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
|
||||
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder,
|
||||
int kWidth, ArrayRef<Value> smemStrides,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, int perPhase, int maxPhase,
|
||||
int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
const Location &loc);
|
||||
|
||||
// lane = thread % 32
|
||||
// warpOff = (thread/32) % warpsPerTile(0)
|
||||
llvm::SmallVector<Value> computeOffsets(Value warpOff, Value lane,
|
||||
Value cSwizzleOffset) {
|
||||
if (canUseLdmatrix)
|
||||
return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset);
|
||||
else
|
||||
return computeLdsMatOffs(warpOff, lane, cSwizzleOffset);
|
||||
return {};
|
||||
}
|
||||
|
||||
int getNumPtrs() const { return numPtrs; }
|
||||
|
||||
// Compute the offset to the matrix this thread(indexed by warpOff and lane)
|
||||
// mapped to.
|
||||
SmallVector<Value> computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
Value cSwizzleOffset);
|
||||
// compute 8-bit matrix offset.
|
||||
SmallVector<Value> computeLdsMatOffs(Value warpOff, Value lane,
|
||||
Value cSwizzleOffset);
|
||||
|
||||
// Load 4 matrices and returns 4 vec<2> elements.
|
||||
std::tuple<Value, Value, Value, Value> loadX4(int mat0, int mat1,
|
||||
ArrayRef<Value> ptrs,
|
||||
Type matTy,
|
||||
Type shemPtrTy) const;
|
||||
|
||||
private:
|
||||
SmallVector<uint32_t> order;
|
||||
SmallVector<uint32_t> warpsPerCTA;
|
||||
int kOrder;
|
||||
int kWidth;
|
||||
int vecWidth;
|
||||
SmallVector<int64_t> tileShape;
|
||||
SmallVector<int> instrShape;
|
||||
SmallVector<int> matShape;
|
||||
int perPhase;
|
||||
int maxPhase;
|
||||
int elemBytes;
|
||||
ConversionPatternRewriter &rewriter;
|
||||
const Location &loc;
|
||||
MLIRContext *ctx{};
|
||||
|
||||
// ldmatrix loads a matrix of size stridedMatShape x contiguousMatShape
|
||||
int contiguousMatShape;
|
||||
int stridedMatShape;
|
||||
|
||||
// Offset in shared memory to increment on the strided axis
|
||||
// This would be different than the tile shape in the case of a sliced tensor
|
||||
Value stridedSmemOffset;
|
||||
|
||||
bool needTrans;
|
||||
bool canUseLdmatrix;
|
||||
|
||||
int numPtrs;
|
||||
|
||||
// Load operations offset in number of Matrices on contiguous and strided axes
|
||||
int contiguousLoadMatOffset;
|
||||
int stridedLoadMatOffset;
|
||||
|
||||
// Offset in number of matrices to increment on non-k dim within a warp's 2x2
|
||||
// matrices
|
||||
int inWarpMatOffset;
|
||||
// Offset in number of matrices to increment on non-k dim across warps
|
||||
int warpMatOffset;
|
||||
|
||||
int nPerWarp;
|
||||
};
|
||||
|
||||
SmallVector<Value>
|
||||
MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
Value cSwizzleOffset) {
|
||||
// 4x4 matrices
|
||||
Value rowInMat = urem(lane, i32_val(8)); // row in the 8x8 matrix
|
||||
Value matIndex =
|
||||
udiv(lane, i32_val(8)); // linear index of the matrix in the 2x2 matrices
|
||||
|
||||
// Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in a
|
||||
// warp
|
||||
Value s0 = urem(matIndex, i32_val(2));
|
||||
Value s1 = udiv(matIndex, i32_val(2));
|
||||
|
||||
// We use different orders for a and b for better performance.
|
||||
Value kMatArr = kOrder == 1 ? s1 : s0; // index of matrix on the k dim
|
||||
Value nkMatArr = kOrder == 1 ? s0 : s1; // index of matrix on the non-k dim
|
||||
|
||||
// Matrix coordinates inside a CTA,
|
||||
// the matrix layout is [2warpsPerTile[0], 2] for A and [2, 2warpsPerTile[1]]
|
||||
// for B. e.g., Setting warpsPerTile=4, the data layout for A(kOrder=1) is
|
||||
// |0 0| -> 0,1,2,3 are the warpids
|
||||
// |0 0|
|
||||
// |1 1|
|
||||
// |1 1|
|
||||
// |2 2|
|
||||
// |2 2|
|
||||
// |3 3|
|
||||
// |3 3|
|
||||
//
|
||||
// for B(kOrder=0) is
|
||||
// |0 1 2 3 0 1 2 3| -> 0,1,2,3 are the warpids
|
||||
// |0 1 2 3 0 1 2 3|
|
||||
// Note, for each warp, it handles a 2x2 matrices, that is the coordinate
|
||||
// address (s0,s1) annotates.
|
||||
|
||||
Value matOff[2];
|
||||
// When B's shape(k, n) is (16, 8) and ldmatrix.x4 is used, the shared memory
|
||||
// access will be out of bound. In the future we should change this case to
|
||||
// ldmatrix.x2
|
||||
if (kOrder == 0 && nPerWarp == 8) {
|
||||
matOff[kOrder ^ 1] = mul(warpId, i32_val(warpMatOffset));
|
||||
} else {
|
||||
matOff[kOrder ^ 1] = add(
|
||||
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
|
||||
mul(nkMatArr,
|
||||
i32_val(
|
||||
inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
|
||||
}
|
||||
matOff[kOrder] = kMatArr;
|
||||
|
||||
// Physical offset (before swizzling)
|
||||
Value contiguousMatIndex = matOff[order[0]];
|
||||
Value stridedMatIndex = matOff[order[1]];
|
||||
// Add the offset of the slice
|
||||
Value contiguousSliceMatOffset =
|
||||
udiv(cSwizzleOffset, i32_val(contiguousMatShape));
|
||||
|
||||
SmallVector<Value> offs(numPtrs);
|
||||
Value phase = urem(udiv(rowInMat, i32_val(perPhase)), i32_val(maxPhase));
|
||||
// To prevent out-of-bound access of B when warpsPerTile * 16 > tile_size.
|
||||
// In such a case, we need to wrap around the offset of B.
|
||||
// |0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) |
|
||||
// |0 1 2 3 0 1 2 3| | 0(0) 1(1) 2(2) 3(3) |
|
||||
// ~~~~~~~ out-of-bound access
|
||||
|
||||
Value rowOffset =
|
||||
urem(add(rowInMat, mul(stridedMatIndex, i32_val(stridedMatShape))),
|
||||
i32_val(tileShape[order[1]]));
|
||||
auto contiguousTileNumMats = tileShape[order[0]] / matShape[order[0]];
|
||||
|
||||
for (int i = 0; i < numPtrs; ++i) {
|
||||
Value contiguousIndex =
|
||||
add(contiguousMatIndex, i32_val(i * contiguousLoadMatOffset));
|
||||
if (warpsPerCTA[order[0]] > contiguousTileNumMats ||
|
||||
contiguousTileNumMats % warpsPerCTA[order[0]] != 0)
|
||||
contiguousIndex = urem(contiguousIndex, i32_val(contiguousTileNumMats));
|
||||
contiguousIndex = add(contiguousIndex, contiguousSliceMatOffset);
|
||||
Value contiguousIndexSwizzled = xor_(contiguousIndex, phase);
|
||||
offs[i] = add(mul(contiguousIndexSwizzled, i32_val(contiguousMatShape)),
|
||||
mul(rowOffset, stridedSmemOffset));
|
||||
}
|
||||
|
||||
return offs;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
// Each `ldmatrix.x4` loads data as follows when `needTrans == False`:
|
||||
//
|
||||
// quad width
|
||||
// <----------------------------------------->
|
||||
// vecWidth
|
||||
// <------->
|
||||
// *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || *t0 ... *t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\
|
||||
// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 |
|
||||
// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height
|
||||
// ... |
|
||||
// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/
|
||||
// --------------------------------------------- || --------------------------------------------
|
||||
// *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3
|
||||
// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7
|
||||
// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11
|
||||
// ...
|
||||
// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31
|
||||
//
|
||||
// we assume that the phase is < 8 so we don't need to maintain a separate pointer for the two
|
||||
// lower quadrants. This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
|
||||
// along the row (resp. col) dimension.
|
||||
// clang-format on
|
||||
|
||||
SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
|
||||
Value lane,
|
||||
Value cSwizzleOffset) {
|
||||
int cTileShape = tileShape[order[0]];
|
||||
int sTileShape = tileShape[order[1]];
|
||||
if (!needTrans) {
|
||||
std::swap(cTileShape, sTileShape);
|
||||
}
|
||||
|
||||
SmallVector<Value> offs(numPtrs);
|
||||
|
||||
int threadsPerQuad[2] = {8, 4};
|
||||
int laneWidth = 4;
|
||||
int laneHeight = 8;
|
||||
int quadWidth = laneWidth * kWidth;
|
||||
int quadHeight = laneHeight;
|
||||
int numQuadI = 2;
|
||||
|
||||
// outer index base
|
||||
Value iBase = udiv(lane, i32_val(laneWidth));
|
||||
|
||||
for (int rep = 0; rep < numPtrs / (2 * kWidth); ++rep)
|
||||
for (int quadId = 0; quadId < 2; ++quadId)
|
||||
for (int elemId = 0; elemId < kWidth; ++elemId) {
|
||||
// inner index base
|
||||
Value jBase = mul(urem(lane, i32_val(laneWidth)), i32_val(kWidth));
|
||||
jBase = add(jBase, i32_val(elemId));
|
||||
// inner index offset
|
||||
Value jOff = i32_val(0);
|
||||
if (!needTrans) {
|
||||
jOff = add(jOff, i32_val(quadId));
|
||||
jOff = add(jOff, i32_val(rep * contiguousLoadMatOffset));
|
||||
}
|
||||
// outer index offset
|
||||
Value iOff = mul(warpOff, i32_val(warpMatOffset));
|
||||
if (needTrans) {
|
||||
int pStride = kOrder == 1 ? 1 : 2;
|
||||
iOff = add(iOff, i32_val(quadId * inWarpMatOffset));
|
||||
iOff = add(iOff, i32_val(rep * contiguousLoadMatOffset * pStride));
|
||||
}
|
||||
// swizzle
|
||||
if (!needTrans) {
|
||||
Value phase = urem(udiv(iBase, i32_val(perPhase)), i32_val(maxPhase));
|
||||
jOff = add(jOff, udiv(cSwizzleOffset, i32_val(quadWidth)));
|
||||
jOff = xor_(jOff, phase);
|
||||
} else {
|
||||
Value phase = urem(udiv(jBase, i32_val(perPhase)), i32_val(maxPhase));
|
||||
iOff = add(iOff, udiv(cSwizzleOffset, i32_val(quadHeight)));
|
||||
iOff = xor_(iOff, phase);
|
||||
}
|
||||
// To prevent out-of-bound access when tile is too small.
|
||||
Value i = add(iBase, mul(iOff, i32_val(quadHeight)));
|
||||
Value j = add(jBase, mul(jOff, i32_val(quadWidth)));
|
||||
// Compute id of this ptr
|
||||
int idx = rep * 2 * kWidth;
|
||||
if (needTrans) {
|
||||
idx += quadId * vecWidth;
|
||||
idx += elemId % vecWidth;
|
||||
idx += elemId / vecWidth * kWidth;
|
||||
} else {
|
||||
idx += quadId * kWidth;
|
||||
idx += elemId;
|
||||
}
|
||||
|
||||
if (needTrans) {
|
||||
offs[idx] = add(i, mul(j, stridedSmemOffset));
|
||||
} else {
|
||||
offs[idx] = add(mul(i, stridedSmemOffset), j);
|
||||
}
|
||||
}
|
||||
|
||||
return offs;
|
||||
}
|
||||
|
||||
std::tuple<Value, Value, Value, Value>
|
||||
MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
|
||||
Type shemPtrTy) const {
|
||||
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
|
||||
int matIdx[2] = {mat0, mat1};
|
||||
|
||||
int ptrIdx{-1};
|
||||
|
||||
if (canUseLdmatrix)
|
||||
ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
|
||||
else
|
||||
ptrIdx = matIdx[order[0]] * (needTrans ? kWidth : vecWidth);
|
||||
|
||||
// The main difference with the original triton code is we removed the
|
||||
// prefetch-related logic here for the upstream optimizer phase should
|
||||
// take care with it, and that is transparent in dot conversion.
|
||||
auto getPtr = [&](int idx) { return ptrs[idx]; };
|
||||
Value ptr = getPtr(ptrIdx);
|
||||
|
||||
// The struct should have exactly the same element types.
|
||||
auto resTy = matTy.cast<LLVM::LLVMStructType>();
|
||||
Type elemTy = matTy.cast<LLVM::LLVMStructType>().getBody()[0];
|
||||
|
||||
// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
|
||||
// instructions to pack & unpack sub-word integers. A workaround is to
|
||||
// store the results of ldmatrix in i32
|
||||
if (auto vecElemTy = elemTy.dyn_cast<VectorType>()) {
|
||||
Type elemElemTy = vecElemTy.getElementType();
|
||||
if (auto intTy = elemElemTy.dyn_cast<IntegerType>()) {
|
||||
if (intTy.getWidth() <= 16) {
|
||||
elemTy = rewriter.getI32Type();
|
||||
resTy =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, elemTy));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
Value stridedOffset =
|
||||
mul(i32_val(matIdx[order[1]] * stridedLoadMatOffset * stridedMatShape),
|
||||
stridedSmemOffset);
|
||||
Value readPtr = gep(shemPtrTy, ptr, stridedOffset);
|
||||
|
||||
PTXBuilder builder;
|
||||
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
|
||||
// thread.
|
||||
auto resArgs = builder.newListOperand(4, "=r");
|
||||
auto addrArg = builder.newAddrOperand(readPtr, "r");
|
||||
|
||||
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
|
||||
->o("trans", needTrans /*predicate*/)
|
||||
.o("shared.b16");
|
||||
ldmatrix(resArgs, addrArg);
|
||||
|
||||
// The result type is 4xi32, each i32 is composed of 2xf16
|
||||
// elements (adjacent two columns in a row) or a single f32 element.
|
||||
Value resV4 = builder.launch(rewriter, loc, resTy);
|
||||
return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
|
||||
extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
|
||||
} else {
|
||||
// base pointers
|
||||
std::array<std::array<Value, 4>, 2> ptrs;
|
||||
for (int i = 0; i < vecWidth; i++)
|
||||
ptrs[0][i] = getPtr(ptrIdx + i);
|
||||
for (int i = 0; i < vecWidth; i++)
|
||||
ptrs[1][i] = getPtr(ptrIdx + i + vecWidth);
|
||||
// static offsets along outer dimension
|
||||
int _i0 = matIdx[order[1]] * (stridedLoadMatOffset * stridedMatShape);
|
||||
int _i1 = _i0;
|
||||
if (needTrans)
|
||||
_i1 += (kWidth != vecWidth) ? vecWidth
|
||||
: stridedLoadMatOffset * stridedMatShape;
|
||||
else
|
||||
_i1 += (kOrder == 1 ? 1 : stridedLoadMatOffset) * stridedMatShape;
|
||||
Value i0 = mul(i32_val(_i0), stridedSmemOffset);
|
||||
Value i1 = mul(i32_val(_i1), stridedSmemOffset);
|
||||
std::array<Value, 2> ii = {i0, i1};
|
||||
// load 4 32-bit values from shared memory
|
||||
// (equivalent to ldmatrix.x4)
|
||||
SmallVector<SmallVector<Value>> vptrs(4, SmallVector<Value>(vecWidth));
|
||||
|
||||
for (int i = 0; i < 4; ++i)
|
||||
for (int j = 0; j < vecWidth; ++j) {
|
||||
vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]);
|
||||
}
|
||||
// row + trans and col + no-trans are equivalent
|
||||
bool isActualTrans =
|
||||
(needTrans && kOrder == 1) || (!needTrans && kOrder == 0);
|
||||
// pack loaded vectors into 4 32-bit values
|
||||
int inc = needTrans ? 1 : kWidth;
|
||||
VectorType packedTy = vec_ty(int_ty(8 * elemBytes), inc);
|
||||
int canonBits = std::min(32, 8 * elemBytes * inc);
|
||||
int canonWidth = (8 * elemBytes * inc) / canonBits;
|
||||
Type canonInt = int_ty(canonBits);
|
||||
std::array<Value, 4> retElems;
|
||||
retElems.fill(undef(vec_ty(canonInt, 32 / canonBits)));
|
||||
for (int r = 0; r < 2; ++r) {
|
||||
for (int em = 0; em < 2 * vecWidth; em += inc) {
|
||||
int e = em % vecWidth;
|
||||
int m = em / vecWidth;
|
||||
int idx = m * 2 + r;
|
||||
Value ptr = bitcast(vptrs[idx][e], ptr_ty(packedTy, 3));
|
||||
Value val = load(ptr);
|
||||
Value canonval = bitcast(val, vec_ty(canonInt, canonWidth));
|
||||
for (int w = 0; w < canonWidth; ++w) {
|
||||
int ridx = idx + w * kWidth / vecWidth;
|
||||
retElems[ridx] =
|
||||
insert_element(retElems[ridx],
|
||||
extract_element(canonval, i32_val(w)), i32_val(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (isActualTrans)
|
||||
std::swap(retElems[1], retElems[2]);
|
||||
return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
|
||||
bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
|
||||
}
|
||||
}
|
||||
|
||||
MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
|
||||
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder, int kWidth,
|
||||
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
|
||||
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
|
||||
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
|
||||
: nPerWarp(nPerWarp), order(order.begin(), order.end()),
|
||||
warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
|
||||
kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
|
||||
instrShape(instrShape.begin(), instrShape.end()),
|
||||
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
|
||||
maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
|
||||
ctx(rewriter.getContext()) {
|
||||
contiguousMatShape = matShape[order[0]];
|
||||
stridedMatShape = matShape[order[1]];
|
||||
stridedSmemOffset = smemStrides[order[1]];
|
||||
vecWidth = 4 / elemBytes;
|
||||
|
||||
// rule: k must be the fast-changing axis.
|
||||
needTrans = kOrder != order[0];
|
||||
canUseLdmatrix = elemBytes == 2 || (!needTrans);
|
||||
canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth);
|
||||
// canUseLdmatrix = false;
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
// Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed,
|
||||
// otherwise [warpsPerTilex1], and each warp will perform a mma.
|
||||
numPtrs = tileShape[order[0]] / (needTrans ? warpsPerTile : 1) /
|
||||
instrShape[order[0]];
|
||||
} else {
|
||||
numPtrs = tileShape[order[0]] / (needTrans ? warpsPerTile : 1) /
|
||||
matShape[order[0]];
|
||||
numPtrs *= kWidth;
|
||||
}
|
||||
numPtrs = std::max<int>(numPtrs, 2);
|
||||
|
||||
// Special rule for i8/u8, 4 ptrs for each matrix
|
||||
// if (!canUseLdmatrix && elemBytes == 1)
|
||||
|
||||
int loadOffsetInMat[2];
|
||||
loadOffsetInMat[kOrder] =
|
||||
2; // instrShape[kOrder] / matShape[kOrder], always 2
|
||||
loadOffsetInMat[kOrder ^ 1] =
|
||||
warpsPerTile * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
|
||||
|
||||
contiguousLoadMatOffset = loadOffsetInMat[order[0]];
|
||||
|
||||
stridedLoadMatOffset =
|
||||
loadOffsetInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);
|
||||
|
||||
// The stride (in number of matrices) within warp
|
||||
inWarpMatOffset = kOrder == 1 ? 1 : warpsPerTile;
|
||||
// The stride (in number of matrices) of each warp
|
||||
warpMatOffset = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1];
|
||||
}
|
||||
|
||||
Type getSharedMemPtrTy(Type argType) {
|
||||
MLIRContext *ctx = argType.getContext();
|
||||
if (argType.isF16())
|
||||
return ptr_ty(type::f16Ty(ctx), 3);
|
||||
else if (argType.isBF16())
|
||||
return ptr_ty(type::i16Ty(ctx), 3);
|
||||
else if (argType.isF32())
|
||||
return ptr_ty(type::f32Ty(ctx), 3);
|
||||
else if (argType.getIntOrFloatBitWidth() == 8)
|
||||
return ptr_ty(type::i8Ty(ctx), 3);
|
||||
else
|
||||
llvm::report_fatal_error("mma16816 data type not supported");
|
||||
}
|
||||
|
||||
Value composeValuesToDotOperandLayoutStruct(
|
||||
const ValueTable &vals, int n0, int n1,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Location loc,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
std::vector<Value> elems;
|
||||
for (int m = 0; m < n0; ++m)
|
||||
for (int k = 0; k < n1; ++k) {
|
||||
elems.push_back(vals.at({2 * m, 2 * k}));
|
||||
elems.push_back(vals.at({2 * m, 2 * k + 1}));
|
||||
elems.push_back(vals.at({2 * m + 1, 2 * k}));
|
||||
elems.push_back(vals.at({2 * m + 1, 2 * k + 1}));
|
||||
}
|
||||
|
||||
assert(!elems.empty());
|
||||
|
||||
Type elemTy = elems[0].getType();
|
||||
MLIRContext *ctx = elemTy.getContext();
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(elems.size(), elemTy));
|
||||
auto result = typeConverter->packLLElements(loc, elems, rewriter, structTy);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::function<void(int, int)> getLoadMatrixFn(
|
||||
Value tensor, const SharedMemoryObject &smemObj, MmaEncodingAttr mmaLayout,
|
||||
int warpsPerTile, uint32_t kOrder, int kWidth, SmallVector<int> instrShape,
|
||||
SmallVector<int> matShape, Value warpId, Value lane, ValueTable &vals,
|
||||
bool isA, TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto shapePerCTA = getShapePerCTA(tensorTy);
|
||||
Type eltTy = tensorTy.getElementType();
|
||||
// We assumes that the input operand of Dot should be from shared layout.
|
||||
// TODO(Superjomn) Consider other layouts if needed later.
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
const int perPhase = sharedLayout.getPerPhase();
|
||||
const int maxPhase = sharedLayout.getMaxPhase();
|
||||
const int vecPhase = sharedLayout.getVec();
|
||||
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
if (kWidth != (4 / elemBytes))
|
||||
assert(vecPhase == 1 || vecPhase == 4 * kWidth);
|
||||
|
||||
int nPerWarp =
|
||||
std::max<int>(shapePerCTA[1] / mmaLayout.getWarpsPerCTA()[1], 8);
|
||||
|
||||
// (a, b) is the coordinate.
|
||||
auto load = [=, &rewriter, &vals](int a, int b) {
|
||||
MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(),
|
||||
mmaLayout.getWarpsPerCTA(), kOrder, kWidth,
|
||||
smemObj.strides, shapePerCTA /*tileShape*/,
|
||||
instrShape, matShape, perPhase, maxPhase,
|
||||
elemBytes, rewriter, typeConverter, loc);
|
||||
// Offset of a slice within the original tensor in shared memory
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
SmallVector<Value> offs =
|
||||
loader.computeOffsets(warpId, lane, cSwizzleOffset);
|
||||
// initialize pointers
|
||||
const int numPtrs = loader.getNumPtrs();
|
||||
SmallVector<Value> ptrs(numPtrs);
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
Type smemPtrTy = getSharedMemPtrTy(eltTy);
|
||||
for (int i = 0; i < numPtrs; ++i)
|
||||
ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy);
|
||||
// actually load from shared memory
|
||||
auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(),
|
||||
SmallVector<Type>(4, i32_ty));
|
||||
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
|
||||
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, ptrs,
|
||||
matTy, getSharedMemPtrTy(eltTy));
|
||||
if (!isA)
|
||||
std::swap(ha1, ha2);
|
||||
// the following is incorrect
|
||||
// but causes dramatically better performance in ptxas
|
||||
// although it only changes the order of operands in
|
||||
// `mma.sync`
|
||||
// if(isA)
|
||||
// std::swap(ha1, ha2);
|
||||
// update user-provided values in-place
|
||||
vals[{a, b}] = ha0;
|
||||
vals[{a + 1, b}] = ha1;
|
||||
vals[{a, b + 1}] = ha2;
|
||||
vals[{a + 1, b + 1}] = ha3;
|
||||
};
|
||||
|
||||
return load;
|
||||
}
|
||||
|
||||
Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
DotOperandEncodingAttr encoding,
|
||||
const SharedMemoryObject &smemObj,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value thread,
|
||||
bool isA) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto shapePerCTA = getShapePerCTA(tensorTy);
|
||||
int bitwidth = tensorTy.getElementTypeBitWidth();
|
||||
auto mmaLayout = encoding.getParent().cast<MmaEncodingAttr>();
|
||||
|
||||
ValueTable vals;
|
||||
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
|
||||
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
|
||||
|
||||
auto numRep = encoding.getMMAv2Rep(shapePerCTA, bitwidth);
|
||||
int kWidth = encoding.getKWidth();
|
||||
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
auto order = triton::gpu_rocm::getOrder(mmaLayout);
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value lane = urem(thread, i32_val(32));
|
||||
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warp, warpsPerCTA, order);
|
||||
Value warpM = urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / 16));
|
||||
Value warpN = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 8));
|
||||
|
||||
int warpsPerTile;
|
||||
if (isA)
|
||||
warpsPerTile = std::min<int>(warpsPerCTA[0], shapePerCTA[0] / 16);
|
||||
else
|
||||
warpsPerTile = std::min<int>(warpsPerCTA[1], shapePerCTA[1] / 16);
|
||||
|
||||
std::function<void(int, int)> loadFn;
|
||||
if (isA)
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, warpsPerTile /*warpsPerTile*/, 1 /*kOrder*/,
|
||||
kWidth, {mmaInstrM, mmaInstrK} /*instrShape*/,
|
||||
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/,
|
||||
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
else
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, warpsPerTile /*warpsPerTile*/, 0 /*kOrder*/,
|
||||
kWidth, {mmaInstrK, mmaInstrN} /*instrShape*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/,
|
||||
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
|
||||
// Perform loading.
|
||||
int numRepOuter = isA ? numRep[0] : std::max<int>(numRep[1] / 2, 1);
|
||||
int numRepK = isA ? numRep[1] : numRep[0];
|
||||
for (int m = 0; m < numRepOuter; ++m)
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
loadFn(2 * m, 2 * k);
|
||||
|
||||
// Format the values to LLVM::Struct to passing to mma codegen.
|
||||
return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK,
|
||||
typeConverter, loc, rewriter);
|
||||
}
|
||||
|
||||
namespace SharedToDotOperandMMAv2 {
|
||||
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value tensor, DotOperandEncodingAttr encoding,
|
||||
const SharedMemoryObject &smemObj,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
|
||||
if (opIdx == 0)
|
||||
return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter,
|
||||
thread, true);
|
||||
else {
|
||||
assert(opIdx == 1);
|
||||
return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter,
|
||||
thread, false);
|
||||
}
|
||||
}
|
||||
} // namespace SharedToDotOperandMMAv2
|
||||
165
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM.cpp
vendored
Normal file
165
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,165 @@
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::MmaEncodingAttr;
|
||||
|
||||
LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
LogicalResult convertMMA884(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
#if 1
|
||||
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
#endif
|
||||
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Value thread);
|
||||
|
||||
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
|
||||
triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value thread);
|
||||
|
||||
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
// D = A * B + C
|
||||
Value A = op.getA();
|
||||
Value D = op.getResult();
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShapePerCTA = getShapePerCTA(A.getType());
|
||||
size_t reduceAxis = 1;
|
||||
unsigned K = AShapePerCTA[reduceAxis];
|
||||
bool isOuter = K == 1;
|
||||
|
||||
MmaEncodingAttr mmaLayout = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MmaEncodingAttr>();
|
||||
if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) {
|
||||
if (mmaLayout.isVolta())
|
||||
return convertMMA884(op, adaptor, getTypeConverter(), rewriter);
|
||||
if (mmaLayout.isTuring())
|
||||
return convertMMA1688(op, adaptor, getTypeConverter(), rewriter);
|
||||
if (mmaLayout.isAmpere())
|
||||
return convertMMA16816(op, adaptor, getTypeConverter(), rewriter);
|
||||
if (mmaLayout.isHopper())
|
||||
return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
|
||||
getThreadId(rewriter, loc));
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported MMA kind found when converting DotOp to LLVM.");
|
||||
}
|
||||
|
||||
#if 1
|
||||
MfmaEncodingAttr mfmaLayout = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MfmaEncodingAttr>();
|
||||
if (!isOuter && mfmaLayout && supportMFMA(op, mfmaLayout.getNonKDim())) {
|
||||
return convertMFMA(op, adaptor, getTypeConverter(), rewriter);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.isa<BlockedEncodingAttr>())
|
||||
return convertFMADot(op, adaptor, getTypeConverter(), rewriter);
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported DotOp found when converting TritonGPU to LLVM.");
|
||||
}
|
||||
};
|
||||
|
||||
struct DotAsyncOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotAsyncOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::DotAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::DotAsyncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
// D = A * B + C
|
||||
Value A = op.getA();
|
||||
Value D = op.getResult();
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShapePerCTA = getShapePerCTA(A.getType());
|
||||
size_t reduceAxis = 1;
|
||||
unsigned K = AShapePerCTA[reduceAxis];
|
||||
bool isOuter = K == 1;
|
||||
|
||||
MmaEncodingAttr mmaLayout = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MmaEncodingAttr>();
|
||||
if (!isOuter && mmaLayout &&
|
||||
supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) {
|
||||
if (mmaLayout.isHopper()) {
|
||||
return convertAsyncWGMMA(op, adaptor, getTypeConverter(), rewriter,
|
||||
getThreadId(rewriter, loc));
|
||||
}
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported MMA kind found when converting DotAsyncOp to LLVM.");
|
||||
}
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported DotAsyncOp found when converting TritonGPU to LLVM.");
|
||||
}
|
||||
};
|
||||
|
||||
struct DotWaitOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::DotWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto pendings = op.getPendings();
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(op.getLoc(), pendings);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<DotOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<DotAsyncOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<DotWaitOpConversion>(typeConverter, allocation, benefit);
|
||||
}
|
||||
15
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM.h
vendored
Normal file
15
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_DOT_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_DOT_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
102
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/FMA.cpp
vendored
Normal file
102
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/FMA.cpp
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::MmaEncodingAttr;
|
||||
|
||||
using ValueTableFMA = std::map<std::pair<int, int>, Value>;
|
||||
|
||||
static ValueTableFMA getValueTableFromStructFMA(
|
||||
Value val, int K, int n0, int shapePerCTATile, int sizePerThread,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Type type) {
|
||||
ValueTableFMA res;
|
||||
auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type);
|
||||
int index = 0;
|
||||
for (unsigned k = 0; k < K; ++k) {
|
||||
for (unsigned m = 0; m < n0; m += shapePerCTATile)
|
||||
for (unsigned mm = 0; mm < sizePerThread; ++mm) {
|
||||
res[{m + mm, k}] = elems[index++];
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto loc = op.getLoc();
|
||||
|
||||
auto A = op.getA();
|
||||
auto B = op.getB();
|
||||
auto C = op.getC();
|
||||
auto D = op.getResult();
|
||||
|
||||
auto aTensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = D.getType().cast<RankedTensorType>();
|
||||
|
||||
auto aShapePerCTA = getShapePerCTA(aTensorTy);
|
||||
auto bShapePerCTA = getShapePerCTA(bTensorTy);
|
||||
|
||||
BlockedEncodingAttr dLayout =
|
||||
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto order = dLayout.getOrder();
|
||||
auto cc =
|
||||
typeConverter->unpackLLElements(loc, adaptor.getC(), rewriter, dTensorTy);
|
||||
|
||||
Value llA = adaptor.getA();
|
||||
Value llB = adaptor.getB();
|
||||
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
auto shapePerCTATile = getShapePerCTATile(dLayout);
|
||||
|
||||
int K = aShapePerCTA[1];
|
||||
int M = aShapePerCTA[0];
|
||||
int N = bShapePerCTA[1];
|
||||
|
||||
int mShapePerCTATile =
|
||||
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
int mSizePerThread =
|
||||
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
int nShapePerCTATile =
|
||||
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
int nSizePerThread =
|
||||
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
|
||||
auto has =
|
||||
getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread,
|
||||
rewriter, loc, typeConverter, aTensorTy);
|
||||
auto hbs =
|
||||
getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread,
|
||||
rewriter, loc, typeConverter, bTensorTy);
|
||||
|
||||
SmallVector<Value> ret = cc;
|
||||
bool isCRow = order[0] == 1;
|
||||
|
||||
for (unsigned k = 0; k < K; k++) {
|
||||
for (unsigned m = 0; m < M; m += mShapePerCTATile)
|
||||
for (unsigned n = 0; n < N; n += nShapePerCTATile)
|
||||
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
|
||||
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
||||
int mIdx = m / mShapePerCTATile * mSizePerThread + mm;
|
||||
int nIdx = n / nShapePerCTATile * nSizePerThread + nn;
|
||||
|
||||
int z = isCRow
|
||||
? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
|
||||
: nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
|
||||
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
||||
hbs[{n + nn, k}], ret[z]);
|
||||
}
|
||||
}
|
||||
|
||||
auto res = typeConverter->packLLElements(loc, ret, rewriter, dTensorTy);
|
||||
rewriter.replaceOp(op, res);
|
||||
|
||||
return success();
|
||||
}
|
||||
283
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/MFMA.cpp
vendored
Normal file
283
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/MFMA.cpp
vendored
Normal file
@@ -0,0 +1,283 @@
|
||||
#if 1
|
||||
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::MfmaEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
|
||||
enum class MatrixCoreType : uint8_t {
|
||||
// D = AB + C
|
||||
FP32_FP16_FP16_FP32 = 0, // default
|
||||
FP32_BF16_BF16_FP32,
|
||||
FP32_BF16_BF16_FP32_1K,
|
||||
FP32_FP32_FP32_FP32,
|
||||
FP64_FP64_FP64_FP64,
|
||||
INT32_INT8_INT8_INT32,
|
||||
NOT_APPLICABLE,
|
||||
};
|
||||
|
||||
struct MFMAInstrDescr {
|
||||
MatrixCoreType coreType;
|
||||
unsigned size;
|
||||
};
|
||||
|
||||
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
|
||||
|
||||
struct DotOpMFMAConversionHelper {
|
||||
MfmaEncodingAttr mfmaLayout;
|
||||
|
||||
ConversionPatternRewriter &rewriter;
|
||||
TritonGPUToLLVMTypeConverter *typeConverter;
|
||||
Location loc;
|
||||
MLIRContext *ctx{};
|
||||
|
||||
explicit DotOpMFMAConversionHelper(
|
||||
MfmaEncodingAttr mfmaLayout, ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Location loc)
|
||||
: mfmaLayout(mfmaLayout), rewriter(rewriter),
|
||||
typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
|
||||
|
||||
Value getThreadId() const {
|
||||
auto llvmIndexTy = typeConverter->getIndexType();
|
||||
auto tid = rewriter.create<::mlir::gpu::ThreadIdOp>(
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
|
||||
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
|
||||
}
|
||||
|
||||
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
auto resType = valC.getType();
|
||||
Value zeroFlag = i32_val(0);
|
||||
switch (mfmaDescr.coreType) {
|
||||
case MatrixCoreType::FP32_FP16_FP16_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
assert(mfmaDescr.size == 32);
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
case MatrixCoreType::FP32_FP32_FP32_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
assert(mfmaDescr.size == 32);
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
case MatrixCoreType::FP64_FP64_FP64_FP64:
|
||||
assert(mfmaDescr.size == 16);
|
||||
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
default:
|
||||
llvm::report_fatal_error("MFMA data type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
static MatrixCoreType getMatrixCoreTypeFromDot(DotOp op) {
|
||||
auto aOperandTy = op.getA().getType();
|
||||
auto tensorTy = aOperandTy.cast<RankedTensorType>();
|
||||
auto elemTy = tensorTy.getElementType();
|
||||
auto dotOpEncoding = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto mfmaEncoding = dotOpEncoding.getParent().cast<MfmaEncodingAttr>();
|
||||
if (elemTy.isF16())
|
||||
return MatrixCoreType::FP32_FP16_FP16_FP32;
|
||||
if (elemTy.isF32())
|
||||
return MatrixCoreType::FP32_FP32_FP32_FP32;
|
||||
if (elemTy.isBF16()) {
|
||||
auto nonKDim = mfmaEncoding.getNonKDim();
|
||||
auto kWidth = dotOpEncoding.getKWidth();
|
||||
if ((nonKDim == 32 && kWidth == 4) || (nonKDim == 16 && kWidth == 4)) {
|
||||
return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
|
||||
} else {
|
||||
assert((nonKDim == 32 && kWidth == 2) ||
|
||||
(nonKDim == 16 && kWidth == 2));
|
||||
return MatrixCoreType::FP32_BF16_BF16_FP32;
|
||||
}
|
||||
}
|
||||
if (elemTy.isInteger(8))
|
||||
return MatrixCoreType::INT32_INT8_INT8_INT32;
|
||||
if (elemTy.isF64())
|
||||
return MatrixCoreType::FP64_FP64_FP64_FP64;
|
||||
return MatrixCoreType::NOT_APPLICABLE;
|
||||
}
|
||||
|
||||
static MFMAInstrDescr getMatrixInstrDescr(DotOp op) {
|
||||
MFMAInstrDescr descr;
|
||||
auto tensorTy = op.getD().getType().cast<RankedTensorType>();
|
||||
auto encoding = tensorTy.getEncoding().cast<MfmaEncodingAttr>();
|
||||
descr.coreType = getMatrixCoreTypeFromDot(op);
|
||||
descr.size = encoding.getNonKDim();
|
||||
return descr;
|
||||
}
|
||||
|
||||
// Conduct the Dot conversion.
|
||||
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
auto mfmaInstrDescr = getMatrixInstrDescr(op);
|
||||
|
||||
Value a = op.getA();
|
||||
Value b = op.getB();
|
||||
Value d = op.getD();
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
auto elemTy = aTensorTy.getElementType();
|
||||
|
||||
auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
|
||||
auto repA = aEncoding.getMFMARep(aTensorTy.getShape(), elemTy);
|
||||
auto repB = bEncoding.getMFMARep(bTensorTy.getShape(), elemTy);
|
||||
|
||||
assert(repA[1] == repB[0]);
|
||||
|
||||
Value loadedA = adaptor.getA();
|
||||
Value loadedB = adaptor.getB();
|
||||
Value loadedC = adaptor.getC();
|
||||
|
||||
auto numRepM = repA[0];
|
||||
auto numRepN = repB[1];
|
||||
auto numRepK = repA[1];
|
||||
|
||||
ValueTable ha = getValuesFromDotOperandLayoutStruct(
|
||||
loadedA, numRepM, numRepK, aTensorTy.getElementType());
|
||||
ValueTable hb = getValuesFromDotOperandLayoutStruct(
|
||||
loadedB, numRepN, numRepK, aTensorTy.getElementType());
|
||||
auto dstElemTy = dTensorTy.getElementType();
|
||||
auto fc =
|
||||
typeConverter->unpackLLElements(loc, loadedC, rewriter, dstElemTy);
|
||||
|
||||
unsigned warpSize = triton::gpu_rocm::getWarpSize(mfmaLayout);
|
||||
// compute number of output elements that each thread holds for one MFMA
|
||||
// instruction
|
||||
auto elemsPerVec = nonKDim * nonKDim / warpSize;
|
||||
|
||||
auto vecTy = vec_ty(dstElemTy, elemsPerVec);
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
Value acc = undef(vecTy);
|
||||
for (unsigned v = 0; v < elemsPerVec; ++v) {
|
||||
acc = insert_element(
|
||||
vecTy, acc, fc[m * numRepN * elemsPerVec + n * elemsPerVec + v],
|
||||
i32_val(v));
|
||||
}
|
||||
|
||||
for (size_t k = 0; k < numRepK; k++) {
|
||||
acc =
|
||||
mfmaLayout.getIsTransposed()
|
||||
? generateMFMAOp(mfmaInstrDescr, hb[{n, k}], ha[{m, k}], acc)
|
||||
: generateMFMAOp(mfmaInstrDescr, ha[{m, k}], hb[{n, k}], acc);
|
||||
}
|
||||
for (unsigned v = 0; v < elemsPerVec; ++v) {
|
||||
fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] =
|
||||
extract_element(dstElemTy, acc, i32_val(v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// replace with new packed result
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(fc.size(), dstElemTy));
|
||||
Value res = typeConverter->packLLElements(loc, fc, rewriter, structTy);
|
||||
rewriter.replaceOp(op, res);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1,
|
||||
Type type) const {
|
||||
auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type);
|
||||
ValueTable vals;
|
||||
for (int i = 0; i < n0; i++) {
|
||||
for (int j = 0; j < n1; j++) {
|
||||
vals[{i, j}] = elems[n1 * i + j];
|
||||
}
|
||||
}
|
||||
return vals;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto rankedTType = [](Value tensor) {
|
||||
return tensor.getType().cast<RankedTensorType>();
|
||||
};
|
||||
|
||||
assert(rankedTType(op.getA()).getEncoding().isa<DotOperandEncodingAttr>() &&
|
||||
rankedTType(op.getB()).getEncoding().isa<DotOperandEncodingAttr>() &&
|
||||
"Both $a and %b should be DotOperand layout.");
|
||||
|
||||
auto cTensorTy = rankedTType(op.getC());
|
||||
auto dTensorTy = rankedTType(op.getD());
|
||||
assert(cTensorTy.getEncoding().isa<MfmaEncodingAttr>() &&
|
||||
"Currently, we only support $c with a mfma layout.");
|
||||
|
||||
assert(cTensorTy.getShape()[0] == dTensorTy.getShape()[0] &&
|
||||
cTensorTy.getShape()[1] == dTensorTy.getShape()[1] &&
|
||||
"DotOp's $c operand should pass the same number of values as $d");
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto mfmaLayout = op.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<MfmaEncodingAttr>();
|
||||
|
||||
DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, loc);
|
||||
|
||||
return helper.convertDot(op, adaptor);
|
||||
}
|
||||
|
||||
#endif // ifdef USE_ROCM
|
||||
161
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/MMAv1.cpp
vendored
Normal file
161
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/MMAv1.cpp
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::MmaEncodingAttr;
|
||||
|
||||
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
|
||||
|
||||
static Type getMmaRetType(TensorType operand) {
|
||||
auto *ctx = operand.getContext();
|
||||
Type fp32Ty = type::f32Ty(ctx);
|
||||
// f16*f16+f32->f32
|
||||
return struct_ty(SmallVector<Type>{8, fp32Ty});
|
||||
}
|
||||
|
||||
static ValueTable
|
||||
extractLoadedOperand(Value llStruct, int NK,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Type type) {
|
||||
ValueTable rcds;
|
||||
SmallVector<Value> elems = typeConverter->unpackLLElements(
|
||||
llStruct.getLoc(), llStruct, rewriter, type);
|
||||
|
||||
int offset = 0;
|
||||
for (int i = 0; offset < elems.size(); ++i) {
|
||||
for (int k = 0; k < NK; k += 4) {
|
||||
rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
|
||||
offset += 2;
|
||||
}
|
||||
}
|
||||
|
||||
return rcds;
|
||||
}
|
||||
|
||||
LogicalResult convertMMA884(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto *ctx = op.getContext();
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value A = op.getA();
|
||||
Value B = op.getB();
|
||||
Value D = op.getResult();
|
||||
auto mmaLayout = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<MmaEncodingAttr>();
|
||||
auto ALayout = A.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
auto BLayout = B.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto DTensorTy = D.getType().cast<RankedTensorType>();
|
||||
auto AShape = ATensorTy.getShape();
|
||||
auto BShape = BTensorTy.getShape();
|
||||
|
||||
bool isARow = ALayout.getMMAv1IsRow();
|
||||
bool isBRow = BLayout.getMMAv1IsRow();
|
||||
auto [isARow_, isBRow_, isAVec4_, isBVec4_, _] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
assert(isARow == isARow_);
|
||||
assert(isBRow == isBRow_);
|
||||
|
||||
unsigned numM = ALayout.getMMAv1NumOuter(AShape);
|
||||
unsigned numN = BLayout.getMMAv1NumOuter(BShape);
|
||||
unsigned NK = AShape[1];
|
||||
|
||||
auto has = extractLoadedOperand(adaptor.getA(), NK, rewriter, typeConverter,
|
||||
ATensorTy);
|
||||
auto hbs = extractLoadedOperand(adaptor.getB(), NK, rewriter, typeConverter,
|
||||
BTensorTy);
|
||||
|
||||
// Initialize accumulators with external values, the acc holds the
|
||||
// accumulator value that is shared between the MMA instructions inside a
|
||||
// DotOp, we can call the order of the values the accumulator-internal
|
||||
// order.
|
||||
SmallVector<Value> acc =
|
||||
typeConverter->unpackLLElements(loc, adaptor.getC(), rewriter, DTensorTy);
|
||||
size_t resSize = acc.size();
|
||||
|
||||
// The resVals holds the final result of the DotOp.
|
||||
// NOTE The current order of resVals is different from acc, we call it the
|
||||
// accumulator-external order. and
|
||||
SmallVector<Value> resVals(resSize);
|
||||
|
||||
auto getIdx = [&](int m, int n) {
|
||||
std::vector<size_t> idx{{
|
||||
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
|
||||
(m * 2 + 0) + (n * 4 + 1) * numM,
|
||||
(m * 2 + 1) + (n * 4 + 0) * numM, // row1
|
||||
(m * 2 + 1) + (n * 4 + 1) * numM,
|
||||
(m * 2 + 0) + (n * 4 + 2) * numM, // row2
|
||||
(m * 2 + 0) + (n * 4 + 3) * numM,
|
||||
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
|
||||
(m * 2 + 1) + (n * 4 + 3) * numM,
|
||||
}};
|
||||
return idx;
|
||||
};
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has.at({m, k});
|
||||
auto hb = hbs.at({n, k});
|
||||
|
||||
PTXBuilder builder;
|
||||
auto idx = getIdx(m, n);
|
||||
|
||||
// note: using "=f" for float leads to cleaner PTX
|
||||
bool isIntMMA = DTensorTy.getElementType().isInteger(32);
|
||||
auto *resOprs = builder.newListOperand(8, isIntMMA ? "=r" : "=f");
|
||||
auto *AOprs = builder.newListOperand({
|
||||
{ha.first, "r"},
|
||||
{ha.second, "r"},
|
||||
});
|
||||
|
||||
auto *BOprs = builder.newListOperand({
|
||||
{hb.first, "r"},
|
||||
{hb.second, "r"},
|
||||
});
|
||||
auto *COprs = builder.newListOperand();
|
||||
for (int i = 0; i < 8; ++i)
|
||||
COprs->listAppend(builder.newOperand(acc[idx[i]], std::to_string(i)));
|
||||
|
||||
auto mma = builder.create("mma.sync.aligned.m8n8k4")
|
||||
->o(isARow ? "row" : "col")
|
||||
.o(isBRow ? "row" : "col")
|
||||
.o("f32.f16.f16.f32");
|
||||
|
||||
mma(resOprs, AOprs, BOprs, COprs);
|
||||
|
||||
Value res = builder.launch(rewriter, loc, getMmaRetType(ATensorTy));
|
||||
|
||||
for (auto i = 0; i < 8; i++) {
|
||||
Value elem = extract_val(f32_ty, res, i);
|
||||
acc[idx[i]] = elem;
|
||||
}
|
||||
};
|
||||
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
callMMA(m, n, k);
|
||||
}
|
||||
|
||||
// res holds the same layout of acc
|
||||
for (size_t i = 0; i < acc.size(); ++i) {
|
||||
resVals[i] = acc[i];
|
||||
}
|
||||
|
||||
Value res = typeConverter->packLLElements(loc, resVals, rewriter, DTensorTy);
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
333
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/MMAv2.cpp
vendored
Normal file
333
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/MMAv2.cpp
vendored
Normal file
@@ -0,0 +1,333 @@
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::MmaEncodingAttr;
|
||||
|
||||
using ValueTableV2 = std::map<std::pair<unsigned, unsigned>, Value>;
|
||||
|
||||
Value loadC(Value tensor, Value llTensor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Location loc,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
MLIRContext *ctx = tensor.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
size_t fcSize = triton::gpu_rocm::getTotalElemsPerThread(tensor.getType());
|
||||
|
||||
assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
|
||||
"Currently, we only support $c with a mma layout.");
|
||||
// Load a normal C tensor with mma layout, that should be a
|
||||
// LLVM::struct with fcSize elements.
|
||||
auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
|
||||
assert(structTy.getBody().size() == fcSize &&
|
||||
"DotOp's $c operand should pass the same number of values as $d in "
|
||||
"mma layout.");
|
||||
|
||||
auto numMmaRets = tensorTy.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
assert(numMmaRets == 4 || numMmaRets == 2);
|
||||
if (numMmaRets == 4) {
|
||||
return llTensor;
|
||||
} else if (numMmaRets == 2) {
|
||||
auto cPack = SmallVector<Value>();
|
||||
auto cElemTy = tensorTy.getElementType();
|
||||
int numCPackedElem = 4 / numMmaRets;
|
||||
Type cPackTy = vec_ty(cElemTy, numCPackedElem);
|
||||
for (int i = 0; i < fcSize; i += numCPackedElem) {
|
||||
Value pack = rewriter.create<LLVM::UndefOp>(loc, cPackTy);
|
||||
for (int j = 0; j < numCPackedElem; ++j) {
|
||||
pack = insert_element(
|
||||
cPackTy, pack, extract_val(cElemTy, llTensor, i + j), i32_val(j));
|
||||
}
|
||||
cPack.push_back(pack);
|
||||
}
|
||||
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(cPack.size(), cPackTy));
|
||||
Value result =
|
||||
typeConverter->packLLElements(loc, cPack, rewriter, structTy);
|
||||
return result;
|
||||
}
|
||||
|
||||
return llTensor;
|
||||
}
|
||||
|
||||
ValueTableV2 getValuesFromDotOperandLayoutStruct(
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Location loc,
|
||||
ConversionPatternRewriter &rewriter, Value value, int n0, int n1,
|
||||
RankedTensorType type) {
|
||||
|
||||
auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type);
|
||||
int offset{};
|
||||
ValueTableV2 vals;
|
||||
for (int i = 0; i < n0; ++i) {
|
||||
for (int j = 0; j < n1; j++) {
|
||||
vals[{2 * i, 2 * j}] = elems[offset++];
|
||||
vals[{2 * i, 2 * j + 1}] = elems[offset++];
|
||||
vals[{2 * i + 1, 2 * j}] = elems[offset++];
|
||||
vals[{2 * i + 1, 2 * j + 1}] = elems[offset++];
|
||||
}
|
||||
}
|
||||
return vals;
|
||||
}
|
||||
|
||||
enum class TensorCoreType : uint8_t {
|
||||
// floating-point tensor core instr
|
||||
FP32_FP16_FP16_FP32 = 0, // default
|
||||
FP32_BF16_BF16_FP32,
|
||||
FP32_TF32_TF32_FP32,
|
||||
FP16_FP16_FP16_FP16,
|
||||
// integer tensor core instr
|
||||
INT32_INT1_INT1_INT32, // Not implemented
|
||||
INT32_INT4_INT4_INT32, // Not implemented
|
||||
INT32_INT8_INT8_INT32, // Not implemented
|
||||
//
|
||||
NOT_APPLICABLE,
|
||||
};
|
||||
|
||||
Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
|
||||
Type fp32Ty = type::f32Ty(ctx);
|
||||
Type fp16Ty = type::f16Ty(ctx);
|
||||
Type i32Ty = type::i32Ty(ctx);
|
||||
Type fp32x4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
|
||||
Type i32x4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i32Ty));
|
||||
Type fp16x2Pack2Ty = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(2, vec_ty(fp16Ty, 2)));
|
||||
switch (mmaType) {
|
||||
case TensorCoreType::FP32_FP16_FP16_FP32:
|
||||
return fp32x4Ty;
|
||||
case TensorCoreType::FP32_BF16_BF16_FP32:
|
||||
return fp32x4Ty;
|
||||
case TensorCoreType::FP32_TF32_TF32_FP32:
|
||||
return fp32x4Ty;
|
||||
case TensorCoreType::FP16_FP16_FP16_FP16:
|
||||
return fp16x2Pack2Ty;
|
||||
case TensorCoreType::INT32_INT8_INT8_INT32:
|
||||
return i32x4Ty;
|
||||
default:
|
||||
llvm::report_fatal_error("Unsupported mma type found");
|
||||
}
|
||||
|
||||
return Type{};
|
||||
}
|
||||
|
||||
TensorCoreType getMmaType(triton::DotOp op) {
|
||||
Value A = op.getA();
|
||||
Value B = op.getB();
|
||||
auto aTy = A.getType().cast<RankedTensorType>();
|
||||
auto bTy = B.getType().cast<RankedTensorType>();
|
||||
// d = a*b + c
|
||||
auto dTy = op.getD().getType().cast<RankedTensorType>();
|
||||
|
||||
if (dTy.getElementType().isF32()) {
|
||||
if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
|
||||
return TensorCoreType::FP32_FP16_FP16_FP32;
|
||||
if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
|
||||
return TensorCoreType::FP32_BF16_BF16_FP32;
|
||||
if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
|
||||
op.getAllowTF32())
|
||||
return TensorCoreType::FP32_TF32_TF32_FP32;
|
||||
} else if (dTy.getElementType().isInteger(32)) {
|
||||
if (aTy.getElementType().isInteger(8) && bTy.getElementType().isInteger(8))
|
||||
return TensorCoreType::INT32_INT8_INT8_INT32;
|
||||
} else if (dTy.getElementType().isF16()) {
|
||||
if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
|
||||
return TensorCoreType::FP16_FP16_FP16_FP16;
|
||||
}
|
||||
|
||||
return TensorCoreType::NOT_APPLICABLE;
|
||||
}
|
||||
|
||||
inline static const std::map<TensorCoreType, std::string> mmaInstrPtxTuring = {
|
||||
{TensorCoreType::FP32_FP16_FP16_FP32,
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"},
|
||||
|
||||
{TensorCoreType::FP16_FP16_FP16_FP16,
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16"},
|
||||
};
|
||||
|
||||
inline static const std::map<TensorCoreType, std::string> mmaInstrPtxAmpere = {
|
||||
{TensorCoreType::FP32_FP16_FP16_FP32,
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
|
||||
{TensorCoreType::FP32_BF16_BF16_FP32,
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
|
||||
{TensorCoreType::FP32_TF32_TF32_FP32,
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
|
||||
|
||||
{TensorCoreType::INT32_INT1_INT1_INT32,
|
||||
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
|
||||
{TensorCoreType::INT32_INT4_INT4_INT32,
|
||||
"mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
|
||||
{TensorCoreType::INT32_INT8_INT8_INT32,
|
||||
"mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
|
||||
|
||||
{TensorCoreType::FP16_FP16_FP16_FP16,
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"},
|
||||
};
|
||||
|
||||
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value a, Value b, Value c, Value d, Value loadedA,
|
||||
Value loadedB, Value loadedC, DotOp op,
|
||||
DotOpAdaptor adaptor, bool isTuring) {
|
||||
MLIRContext *ctx = c.getContext();
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
|
||||
auto aShapePerCTA = triton::gpu_rocm::getShapePerCTA(aTensorTy);
|
||||
auto bShapePerCTA = triton::gpu_rocm::getShapePerCTA(bTensorTy);
|
||||
auto dShapePerCTA = triton::gpu_rocm::getShapePerCTA(dTensorTy);
|
||||
|
||||
int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
|
||||
auto repA =
|
||||
aTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
|
||||
aShapePerCTA, bitwidth);
|
||||
auto repB =
|
||||
bTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
|
||||
bShapePerCTA, bitwidth);
|
||||
|
||||
assert(repA[1] == repB[0]);
|
||||
int repM = repA[0], repN = repB[1], repK = repA[1];
|
||||
|
||||
// shape / shape_per_cta
|
||||
auto ha = getValuesFromDotOperandLayoutStruct(typeConverter, loc, rewriter,
|
||||
loadedA, repM, repK, aTensorTy);
|
||||
auto hb = getValuesFromDotOperandLayoutStruct(typeConverter, loc, rewriter,
|
||||
loadedB, std::max(repN / 2, 1),
|
||||
repK, bTensorTy);
|
||||
auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy);
|
||||
auto numMmaRets = dTensorTy.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
int numCPackedElem = 4 / numMmaRets;
|
||||
|
||||
auto mmaType = getMmaType(op);
|
||||
|
||||
const auto &mmaInstructions =
|
||||
isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere;
|
||||
|
||||
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
|
||||
unsigned colsPerThread = repN * 2;
|
||||
PTXBuilder builder;
|
||||
auto &mma = *builder.create(mmaInstructions.at(mmaType));
|
||||
// using =r for float32 works but leads to less readable ptx.
|
||||
bool isIntMMA = dTensorTy.getElementType().isInteger(32);
|
||||
bool isAccF16 = dTensorTy.getElementType().isF16();
|
||||
auto retArgs =
|
||||
builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? "=r" : "=f");
|
||||
auto cArgs = builder.newListOperand();
|
||||
for (int i = 0; i < numMmaRets; ++i) {
|
||||
cArgs->listAppend(builder.newOperand(
|
||||
fc[(m * colsPerThread + 4 * n) / numCPackedElem + i],
|
||||
std::to_string(i)));
|
||||
// reuse the output registers
|
||||
}
|
||||
|
||||
if (isTuring) {
|
||||
auto aArgs1 = builder.newListOperand({
|
||||
{ha[{m, k}], "r"},
|
||||
{ha[{m + 1, k}], "r"},
|
||||
});
|
||||
auto bArgs1 = builder.newListOperand({
|
||||
{hb[{n, k}], "r"},
|
||||
});
|
||||
auto aArgs2 = builder.newListOperand({
|
||||
{ha[{m, k + 1}], "r"},
|
||||
{ha[{m + 1, k + 1}], "r"},
|
||||
});
|
||||
auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}});
|
||||
mma(retArgs, aArgs1, bArgs1, cArgs);
|
||||
mma(retArgs, aArgs2, bArgs2, cArgs);
|
||||
} else {
|
||||
auto aArgs = builder.newListOperand({
|
||||
{ha[{m, k}], "r"},
|
||||
{ha[{m + 1, k}], "r"},
|
||||
{ha[{m, k + 1}], "r"},
|
||||
{ha[{m + 1, k + 1}], "r"},
|
||||
});
|
||||
auto bArgs =
|
||||
builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
|
||||
mma(retArgs, aArgs, bArgs, cArgs);
|
||||
}
|
||||
Value mmaOut =
|
||||
builder.launch(rewriter, loc, getMmaRetType(mmaType, op.getContext()));
|
||||
|
||||
Type elemTy = mmaOut.getType().cast<LLVM::LLVMStructType>().getBody()[0];
|
||||
for (int i = 0; i < numMmaRets; ++i) {
|
||||
fc[(m * colsPerThread + 4 * n) / numCPackedElem + i] =
|
||||
extract_val(elemTy, mmaOut, i);
|
||||
}
|
||||
};
|
||||
|
||||
for (int k = 0; k < repK; ++k)
|
||||
for (int m = 0; m < repM; ++m)
|
||||
for (int n = 0; n < repN; ++n)
|
||||
callMma(2 * m, n, 2 * k);
|
||||
|
||||
Type resElemTy = dTensorTy.getElementType();
|
||||
|
||||
// replace with new packed result
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(fc.size() * numCPackedElem, resElemTy));
|
||||
SmallVector<Value> results(fc.size() * numCPackedElem);
|
||||
for (int i = 0; i < fc.size(); ++i) {
|
||||
for (int j = 0; j < numCPackedElem; ++j) {
|
||||
results[i * numCPackedElem + j] =
|
||||
numCPackedElem > 1
|
||||
? bitcast(extract_element(fc[i], i32_val(j)), resElemTy)
|
||||
: bitcast(fc[i], resElemTy);
|
||||
}
|
||||
}
|
||||
Value res = typeConverter->packLLElements(loc, results, rewriter, structTy);
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, bool isTuring) {
|
||||
auto loc = op.getLoc();
|
||||
auto mmaLayout = op.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<MmaEncodingAttr>();
|
||||
|
||||
Value A = op.getA();
|
||||
Value B = op.getB();
|
||||
Value C = op.getC();
|
||||
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
|
||||
assert(ATensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
|
||||
BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
|
||||
"Both $a and %b should be DotOperand layout.");
|
||||
|
||||
Value loadedA, loadedB, loadedC;
|
||||
loadedA = adaptor.getA();
|
||||
loadedB = adaptor.getB();
|
||||
loadedC =
|
||||
loadC(op.getC(), adaptor.getC(), typeConverter, op.getLoc(), rewriter);
|
||||
|
||||
return convertDot(typeConverter, rewriter, op.getLoc(), A, B, C, op.getD(),
|
||||
loadedA, loadedB, loadedC, op, adaptor, isTuring);
|
||||
}
|
||||
|
||||
// Convert to mma.m16n8k8
|
||||
LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
return convertMMA(op, adaptor, typeConverter, rewriter, true /*isTuring*/);
|
||||
}
|
||||
|
||||
// Convert to mma.m16n8k16
|
||||
LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
return convertMMA(op, adaptor, typeConverter, rewriter, false /*isTuring*/);
|
||||
}
|
||||
431
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/WGMMA.cpp
vendored
Normal file
431
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/DotOpToLLVM/WGMMA.cpp
vendored
Normal file
@@ -0,0 +1,431 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTA;
|
||||
using ::mlir::triton::gpu_rocm::getShapePerCTATile;
|
||||
using ::mlir::triton::gpu_rocm::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
|
||||
triton::nvgpu::WGMMAEltType getMmaRetType(Value d) {
|
||||
auto dTy = d.getType().cast<RankedTensorType>().getElementType();
|
||||
if (dTy.isF32()) {
|
||||
return triton::nvgpu::WGMMAEltType::f32;
|
||||
} else if (dTy.isF16()) {
|
||||
return triton::nvgpu::WGMMAEltType::f16;
|
||||
} else if (dTy.isInteger(32)) {
|
||||
return triton::nvgpu::WGMMAEltType::s32;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported mma result type found");
|
||||
}
|
||||
}
|
||||
|
||||
triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
|
||||
auto aTy = a.getType().cast<RankedTensorType>().getElementType();
|
||||
if (aTy.isF16()) {
|
||||
return triton::nvgpu::WGMMAEltType::f16;
|
||||
} else if (aTy.isBF16()) {
|
||||
return triton::nvgpu::WGMMAEltType::bf16;
|
||||
} else if (aTy.isF32() && allowTF32) {
|
||||
return triton::nvgpu::WGMMAEltType::tf32;
|
||||
} else if (aTy.isInteger(8)) {
|
||||
return triton::nvgpu::WGMMAEltType::s8;
|
||||
} else if (aTy.isFloat8E5M2()) {
|
||||
return triton::nvgpu::WGMMAEltType::e5m2;
|
||||
} else if (aTy.isFloat8E4M3FNUZ()) {
|
||||
return triton::nvgpu::WGMMAEltType::e4m3;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported mma operand type found");
|
||||
}
|
||||
}
|
||||
|
||||
mlir::triton::nvgpu::WGMMADescMode
|
||||
getModeFromLayout(const SharedEncodingAttr &layout, uint32_t widthInByte) {
|
||||
int perPhase = layout.getPerPhase();
|
||||
int maxPhase = layout.getMaxPhase();
|
||||
uint32_t swizzlingByteWidth = 0;
|
||||
|
||||
mlir::triton::nvgpu::WGMMADescMode mode;
|
||||
if (perPhase == 4 && maxPhase == 2) {
|
||||
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle32;
|
||||
swizzlingByteWidth = 32;
|
||||
} else if (perPhase == 2 && maxPhase == 4) {
|
||||
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle64;
|
||||
swizzlingByteWidth = 64;
|
||||
} else if (perPhase == 1 && maxPhase == 8) {
|
||||
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle128;
|
||||
swizzlingByteWidth = 128;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported shared layout.");
|
||||
}
|
||||
|
||||
// TODO[biaow]: remove it once we support swizzling size larger than matrix
|
||||
// width, which requires padding the matrix width to the swizzling size when
|
||||
// allocating shared memory.
|
||||
assert(swizzlingByteWidth <= widthInByte &&
|
||||
"swizzling size larger than matrix width is not supported.");
|
||||
return mode;
|
||||
}
|
||||
|
||||
class DotOpMmaV3SmemLoader {
|
||||
public:
|
||||
DotOpMmaV3SmemLoader(Value tensor, const SharedMemoryObject &smemObj,
|
||||
SmallVector<int64_t> shape, Value warpId,
|
||||
unsigned int dimWpt, bool trans,
|
||||
SmallVector<unsigned int> instrShape,
|
||||
ConversionPatternRewriter &rewriter, Location loc)
|
||||
: base(smemObj.base), shape(shape), warpId(warpId), dimWpt(dimWpt),
|
||||
trans(trans), instrShape(instrShape), rewriter(rewriter), loc(loc) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
ord = sharedLayout.getOrder();
|
||||
const int perPhase = sharedLayout.getPerPhase();
|
||||
const int maxPhase = sharedLayout.getMaxPhase();
|
||||
elemBytes = tensorTy.getElementTypeBitWidth() / 8;
|
||||
elemsPerSwizzlingRow = 128 / perPhase / elemBytes;
|
||||
elemsPerSwizzlingRowVal = i32_val(elemsPerSwizzlingRow);
|
||||
|
||||
uint32_t widthInByte = shape[ord[0]] * elemBytes;
|
||||
mode = getModeFromLayout(sharedLayout, widthInByte);
|
||||
|
||||
baseDesc = rewriter.create<triton::nvgpu::WGMMADescCreateOp>(
|
||||
loc, base, i32_val(shape[ord[1]]), mode);
|
||||
}
|
||||
|
||||
Value smemLoad(int a, int b) {
|
||||
Value k = i32_val(b * instrShape[1]);
|
||||
Value m = add(i32_val(a * dimWpt * instrShape[0]),
|
||||
mul(warpId, i32_val(instrShape[0])));
|
||||
if (trans) {
|
||||
std::swap(k, m);
|
||||
}
|
||||
Value leading_offset = mul(udiv(k, elemsPerSwizzlingRowVal),
|
||||
i32_val(shape[ord[1]] * elemsPerSwizzlingRow));
|
||||
Value stride_offset = mul(m, elemsPerSwizzlingRowVal);
|
||||
Value offset = add(add(leading_offset, stride_offset),
|
||||
urem(k, elemsPerSwizzlingRowVal));
|
||||
Value off1 = mul(i32_val(elemBytes), offset);
|
||||
Value off_ = zext(i64_ty, udiv(off1, i32_val(16)));
|
||||
|
||||
return add(baseDesc, off_);
|
||||
}
|
||||
|
||||
private:
|
||||
Value base;
|
||||
SmallVector<int64_t> shape;
|
||||
Value warpId;
|
||||
int dimWpt;
|
||||
bool trans;
|
||||
Value elemsPerSwizzlingRowVal;
|
||||
mlir::triton::nvgpu::WGMMADescMode mode;
|
||||
SmallVector<unsigned int> instrShape;
|
||||
ArrayRef<unsigned> ord;
|
||||
ConversionPatternRewriter &rewriter;
|
||||
Location loc;
|
||||
int elemsPerSwizzlingRow;
|
||||
int elemBytes;
|
||||
Value baseDesc;
|
||||
};
|
||||
|
||||
DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
const MmaEncodingAttr &mmaEncoding, Value tensor,
|
||||
const SharedMemoryObject &smemObj, Value thread) {
|
||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto aSharedLayout = aTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(aSharedLayout && "only support load dot operand from shared.");
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto wpt = mmaEncoding.getWarpsPerCTA();
|
||||
auto aOrd = aSharedLayout.getOrder();
|
||||
bool transA = aOrd[0] == 0;
|
||||
auto shapePerCTA = getShapePerCTA(aTensorTy);
|
||||
|
||||
int numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0] * wpt[0]);
|
||||
int numRepK = ceil<unsigned>(shapePerCTA[1], instrShape[2]);
|
||||
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value warpM = urem(warp, i32_val(wpt[0]));
|
||||
Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
|
||||
|
||||
return {tensor,
|
||||
smemObj,
|
||||
shapePerCTA,
|
||||
warpId,
|
||||
wpt[0],
|
||||
transA,
|
||||
{instrShape[0], instrShape[2]},
|
||||
rewriter,
|
||||
loc};
|
||||
}
|
||||
|
||||
DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
MmaEncodingAttr &mmaEncoding, Value tensor,
|
||||
const SharedMemoryObject &smemObj, Value thread) {
|
||||
auto bTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
assert(bSharedLayout && "only support load B from shared.");
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto wpt = mmaEncoding.getWarpsPerCTA();
|
||||
auto bOrd = bSharedLayout.getOrder();
|
||||
bool transB = bOrd[0] == 1;
|
||||
auto shapePerCTA = triton::gpu_rocm::getShapePerCTA(bTensorTy);
|
||||
|
||||
int numRepK = ceil<unsigned>(shapePerCTA[0], instrShape[2]);
|
||||
int numRepN = ceil<unsigned>(shapePerCTA[1], instrShape[1] * wpt[1]);
|
||||
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value warpMN = udiv(warp, i32_val(wpt[0]));
|
||||
Value warpN = urem(warpMN, i32_val(wpt[1]));
|
||||
Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
|
||||
|
||||
return {tensor,
|
||||
smemObj,
|
||||
shapePerCTA,
|
||||
warpId,
|
||||
wpt[1],
|
||||
transB,
|
||||
{instrShape[1], instrShape[2]},
|
||||
rewriter,
|
||||
loc};
|
||||
}
|
||||
|
||||
// Return a vector of Value of the accumulator start at startIndex and pack the
|
||||
// values into 32bits in case the accumulator is fp16.
|
||||
llvm::SmallVector<Value> loadC(ConversionPatternRewriter &rewriter,
|
||||
Location loc, const SmallVector<Value> &elements,
|
||||
int startIndex, int numElements) {
|
||||
if (!elements[0].getType().isF16()) {
|
||||
llvm::SmallVector<Value> mmaOut(numElements);
|
||||
for (int i = 0; i < numElements; ++i)
|
||||
mmaOut[i] = elements[startIndex + i];
|
||||
return mmaOut;
|
||||
}
|
||||
// For FP16 we need to pack accumulator into 32-bit integers.
|
||||
llvm::SmallVector<Value> mmaOut(numElements / 2);
|
||||
for (int i = 0; i < numElements / 2; ++i) {
|
||||
Value a0 = elements[startIndex + 2 * i];
|
||||
Value a1 = elements[startIndex + 2 * i + 1];
|
||||
Type cPackTy = vec_ty(rewriter.getF16Type(), 2);
|
||||
Value pack = rewriter.create<LLVM::UndefOp>(loc, cPackTy);
|
||||
pack = insert_element(cPackTy, pack, a0, i32_val(0));
|
||||
pack = insert_element(cPackTy, pack, a1, i32_val(1));
|
||||
pack = bitcast(pack, rewriter.getIntegerType(32));
|
||||
mmaOut[i] = pack;
|
||||
}
|
||||
return mmaOut;
|
||||
}
|
||||
|
||||
// If the accumulator is fp16 unpack it from 32-bit integers.
|
||||
SmallVector<Value> unpackAccumulator(ConversionPatternRewriter &rewriter,
|
||||
Location loc,
|
||||
const SmallVector<Value> &packed,
|
||||
RankedTensorType tensorTy) {
|
||||
if (!tensorTy.getElementType().isF16())
|
||||
return packed;
|
||||
// For fp16 the accumualtor is pack into 32-bit integers so we need to unpack
|
||||
// it.
|
||||
SmallVector<Value> results;
|
||||
for (Value elem : packed) {
|
||||
elem = bitcast(elem, vec_ty(rewriter.getF16Type(), 2));
|
||||
results.push_back(extract_element(rewriter.getF16Type(), elem, i32_val(0)));
|
||||
results.push_back(extract_element(rewriter.getF16Type(), elem, i32_val(1)));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *op, Value a, Value b, Value c, Value d,
|
||||
Value loadedA, Value loadedB, Value loadedC,
|
||||
bool allowTF32, const SharedMemoryObject &smemObjA,
|
||||
const SharedMemoryObject &smemObjB, bool sync,
|
||||
Value thread) {
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
auto aSharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto mmaEncoding = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
|
||||
auto aOrd = aSharedLayout.getOrder();
|
||||
auto bOrd = bSharedLayout.getOrder();
|
||||
bool transA = aOrd[0] == 0;
|
||||
bool transB = bOrd[0] == 1;
|
||||
auto dShapePerCTA = getShapePerCTA(dTensorTy);
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto accSize = 2 * (instrShape[1] / 4);
|
||||
int M = 4 * instrShape[0];
|
||||
int N = instrShape[1];
|
||||
int K = instrShape[2];
|
||||
|
||||
auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
|
||||
int numRepM = ceil<unsigned>(dShapePerCTA[0], shapePerCTATile[0]);
|
||||
int numRepN = ceil<unsigned>(dShapePerCTA[1], shapePerCTATile[1]);
|
||||
int numRepK = ceil<unsigned>(aTensorTy.getShape()[1], instrShape[2]);
|
||||
|
||||
DotOpMmaV3SmemLoader aLoader =
|
||||
loadA(typeConverter, rewriter, loc, mmaEncoding, a, smemObjA, thread);
|
||||
DotOpMmaV3SmemLoader bLoader =
|
||||
loadB(typeConverter, rewriter, loc, mmaEncoding, b, smemObjB, thread);
|
||||
|
||||
auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy);
|
||||
|
||||
triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(d);
|
||||
triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(a, allowTF32);
|
||||
triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(b, allowTF32);
|
||||
|
||||
triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col
|
||||
: triton::nvgpu::WGMMALayout::row;
|
||||
triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row
|
||||
: triton::nvgpu::WGMMALayout::col;
|
||||
|
||||
auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
|
||||
int numTMADescs =
|
||||
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
|
||||
if (numTMADescs == 0)
|
||||
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
|
||||
rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
|
||||
|
||||
SmallVector<Value> mmaResults;
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
llvm::SmallVector<Value> mmaOut =
|
||||
loadC(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize);
|
||||
llvm::SmallVector<Type> elemTypes;
|
||||
for (Value accEl : mmaOut)
|
||||
elemTypes.push_back(accEl.getType());
|
||||
auto accTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy);
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto a = aLoader.smemLoad(m, k);
|
||||
auto b = bLoader.smemLoad(n, k);
|
||||
ValueRange operands{a, b, d};
|
||||
d = rewriter.create<triton::nvgpu::WGMMAOp>(loc, accTy, a, b, d, M, N,
|
||||
K, eltTypeC, eltTypeA,
|
||||
eltTypeB, layoutA, layoutB);
|
||||
}
|
||||
auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy);
|
||||
for (int i = 0; i < acc.size(); ++i) {
|
||||
mmaResults.push_back(acc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
|
||||
|
||||
if (sync)
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, 0);
|
||||
|
||||
SmallVector<Value> results =
|
||||
unpackAccumulator(rewriter, loc, mmaResults, dTensorTy);
|
||||
|
||||
// replace with new packed result
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
mmaEncoding.getContext(),
|
||||
SmallVector<Type>(results.size(), dTensorTy.getElementType()));
|
||||
auto res = typeConverter->packLLElements(loc, results, rewriter, structTy);
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Loading $c to registers, returns a Value.
|
||||
Value loadC(Value tensor, Value llTensor) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto mmaEncoding = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
|
||||
assert(mmaEncoding && "Currently, we only support $c with a mma layout.");
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto wpt = mmaEncoding.getWarpsPerCTA();
|
||||
auto shapePerCTA = getShapePerCTA(tensorTy);
|
||||
auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
|
||||
|
||||
int numRepM = ceil<unsigned>(shapePerCTA[0], shapePerCTATile[0]);
|
||||
int numRepN = ceil<unsigned>(shapePerCTA[1], shapePerCTATile[1]);
|
||||
|
||||
size_t fcSize = 2 * (instrShape[1] / 4) * numRepM * numRepN;
|
||||
|
||||
auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
|
||||
assert(structTy.getBody().size() == fcSize &&
|
||||
"DotOp's $c operand should pass the same number of values as $d in "
|
||||
"mma layout.");
|
||||
return llTensor;
|
||||
}
|
||||
|
||||
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Value thread) {
|
||||
auto loc = op.getLoc();
|
||||
Value A = op.getA();
|
||||
Value B = op.getB();
|
||||
Value C = op.getC();
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
|
||||
assert(ATensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
BTensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
"Both $a and %b should be Shared layout.");
|
||||
|
||||
Value llA, llB, llC;
|
||||
llA = adaptor.getA();
|
||||
llB = adaptor.getB();
|
||||
llC = loadC(C, adaptor.getC());
|
||||
|
||||
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
|
||||
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
|
||||
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
|
||||
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
|
||||
smemObjB, true, thread);
|
||||
}
|
||||
|
||||
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
|
||||
triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value thread) {
|
||||
auto loc = op.getLoc();
|
||||
Value A = op.getA();
|
||||
Value B = op.getB();
|
||||
Value C = op.getC();
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
|
||||
assert(ATensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
BTensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
"Both $a and %b should be Shared layout.");
|
||||
|
||||
Value llA, llB, llC;
|
||||
llA = adaptor.getA();
|
||||
llB = adaptor.getB();
|
||||
llC = loadC(C, adaptor.getC());
|
||||
|
||||
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
|
||||
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
|
||||
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
|
||||
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
|
||||
smemObjB, false, thread);
|
||||
}
|
||||
2170
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ElementwiseOpToLLVM.cpp
vendored
Normal file
2170
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ElementwiseOpToLLVM.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
22
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ElementwiseOpToLLVM.h
vendored
Normal file
22
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ElementwiseOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_ELEMENTWISE_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_ELEMENTWISE_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateElementwiseOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
int computeCapability, PatternBenefit benefit);
|
||||
|
||||
bool isLegalElementwiseOp(Operation *op);
|
||||
|
||||
void populateElementwiseOpToPTXPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
191
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/GCNAsmFormat.cpp
vendored
Normal file
191
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/GCNAsmFormat.cpp
vendored
Normal file
@@ -0,0 +1,191 @@
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/GCNAsmFormat.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/AsmFormat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <sstream> // unify to llvm::raw_string_ostream ?
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
GCNInstr::Operand *
|
||||
GCNBuilder::newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int)> formatter) {
|
||||
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
|
||||
auto *opr = argArchive.back().get();
|
||||
opr->repr = formatter;
|
||||
opr->idx = oprCounter++;
|
||||
return opr;
|
||||
}
|
||||
|
||||
GCNBuilder::Operand *GCNBuilder::newOperand(StringRef constraint) {
|
||||
// Constraint should be something like "=r"
|
||||
assert(!constraint.empty() && constraint[0] == '=');
|
||||
auto *opr = newOperand();
|
||||
opr->idx = oprCounter++;
|
||||
opr->constraint = constraint;
|
||||
return opr;
|
||||
}
|
||||
|
||||
GCNBuilder::Modifier *GCNBuilder::newModifier(StringRef modifier,
|
||||
StringRef arg) {
|
||||
assert(!modifier.empty());
|
||||
auto *mod = newModifier();
|
||||
mod->modifier = modifier;
|
||||
mod->arg = arg;
|
||||
return mod;
|
||||
}
|
||||
|
||||
GCNBuilder::Operand *GCNBuilder::newConstantOperand(const std::string &v) {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
argArchive.back()->repr = [v](int idx) { return v; };
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
GCNBuilder::Operand *GCNBuilder::newConstantOperand(int v) {
|
||||
std::stringstream ss;
|
||||
ss << "0x" << std::hex << v;
|
||||
return newConstantOperand(ss.str());
|
||||
}
|
||||
|
||||
std::string GCNBuilder::getConstraints() const {
|
||||
auto args = getAllArgs();
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto arg : args)
|
||||
argReprs.push_back(arg->constraint);
|
||||
return strJoin(argReprs, ",");
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> GCNBuilder::getAllMLIRArgs() const {
|
||||
llvm::SmallVector<Value, 4> res;
|
||||
for (auto &arg : argArchive) {
|
||||
if (!arg->isList() && arg->value)
|
||||
res.push_back(arg->value);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<GCNBuilder::Operand *, 4> GCNBuilder::getAllArgs() const {
|
||||
llvm::SmallVector<Operand *, 4> res;
|
||||
for (auto &x : argArchive)
|
||||
if (!x->isList())
|
||||
res.push_back(x.get());
|
||||
return res;
|
||||
}
|
||||
|
||||
mlir::Value GCNBuilder::launch(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type resTy, bool hasSideEffect,
|
||||
bool isAlignStack,
|
||||
ArrayRef<Attribute> attrs) const {
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||
loc, resTy, getAllMLIRArgs(), // operands
|
||||
dump(), // asm_string
|
||||
getConstraints(), // constraints
|
||||
hasSideEffect, // has_side_effects
|
||||
isAlignStack, // is_align_stack
|
||||
LLVM::AsmDialectAttr::get(ctx,
|
||||
LLVM::AsmDialect::AD_ATT), // asm_dialect
|
||||
ArrayAttr::get(ctx, attrs) // operand_attrs
|
||||
);
|
||||
|
||||
return inlineAsm.getRes();
|
||||
}
|
||||
|
||||
std::string GCNInstr::Operand::dump() const {
|
||||
if (repr)
|
||||
return repr(idx);
|
||||
if (!isList())
|
||||
return "$" + std::to_string(idx);
|
||||
|
||||
llvm::SmallVector<std::string> oprs;
|
||||
for (auto *opr : list)
|
||||
oprs.push_back(opr->dump());
|
||||
return strJoin(oprs, ", ");
|
||||
}
|
||||
|
||||
std::string GCNInstr::Modifier::dump() const {
|
||||
if (!isList())
|
||||
return to_str();
|
||||
|
||||
llvm::SmallVector<std::string> mods;
|
||||
for (auto *mod : list)
|
||||
mods.push_back(mod->dump());
|
||||
return strJoin(mods, " ");
|
||||
}
|
||||
|
||||
GCNInstr::Operand *GCNBuilder::newAddrOperand(mlir::Value addr,
|
||||
StringRef constraint) {
|
||||
auto *opr = newOperand(addr, constraint);
|
||||
opr->repr = [](int idx) -> std::string {
|
||||
std::stringstream ss;
|
||||
ss << "$" << idx;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
return opr;
|
||||
}
|
||||
|
||||
std::string GCNBuilder::dump() const {
|
||||
llvm::SmallVector<std::string> lines;
|
||||
for (auto &exec : executions) {
|
||||
lines.push_back(exec->dump());
|
||||
}
|
||||
|
||||
return strJoin(lines, "\n\t");
|
||||
}
|
||||
|
||||
GCNInstrExecution &GCNInstrCommon::call(ArrayRef<Operand *> oprs,
|
||||
ArrayRef<Modifier *> mods) {
|
||||
builder->executions.emplace_back(
|
||||
std::make_unique<GCNInstrExecution>(this, oprs, mods));
|
||||
return *builder->executions.back();
|
||||
}
|
||||
|
||||
GCNInstrExecution &GCNInstrCommon::operator()(ArrayRef<Operand *> oprs,
|
||||
ArrayRef<Modifier *> mods) {
|
||||
return call(oprs, mods);
|
||||
}
|
||||
|
||||
std::string GCNInstrExecution::dump() const {
|
||||
std::string osStr;
|
||||
llvm::raw_string_ostream os(osStr);
|
||||
|
||||
std::string instrRepr = strJoin(instr->instrParts, "_");
|
||||
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto *arg : argsInOrder) {
|
||||
argReprs.push_back(arg->dump());
|
||||
}
|
||||
|
||||
std::string argsRepr = strJoin(argReprs, ", ");
|
||||
|
||||
llvm::SmallVector<std::string, 4> modReprs;
|
||||
for (auto *mod : mods) {
|
||||
modReprs.push_back(mod->dump());
|
||||
}
|
||||
|
||||
std::string modsRepr = strJoin(modReprs, " ");
|
||||
if (!modsRepr.empty()) {
|
||||
os << instrRepr << " " << argsRepr << " " << modsRepr;
|
||||
} else {
|
||||
os << instrRepr << " " << argsRepr;
|
||||
}
|
||||
os.flush();
|
||||
return osStr;
|
||||
}
|
||||
|
||||
SmallVector<GCNInstrExecution::Operand *>
|
||||
GCNInstrExecution::getArgList() const {
|
||||
SmallVector<Operand *> args;
|
||||
for (auto *arg : argsInOrder) {
|
||||
if (arg->isList())
|
||||
args.insert(args.end(), arg->list.begin(), arg->list.end());
|
||||
else
|
||||
args.push_back(arg);
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
2061
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/LoadStoreOpToLLVM.cpp
vendored
Normal file
2061
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/LoadStoreOpToLLVM.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
17
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/LoadStoreOpToLLVM.h
vendored
Normal file
17
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/LoadStoreOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_LOAD_STORE_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_LOAD_STORE_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateLoadStoreOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
234
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/PTXAsmFormat.cpp
vendored
Normal file
234
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/PTXAsmFormat.cpp
vendored
Normal file
@@ -0,0 +1,234 @@
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/AsmFormat.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/AsmFormat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
// TODO(Superjomn): unify to llvm::raw_string_ostream
|
||||
#include <sstream>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
PTXInstr::Operand *
|
||||
PTXBuilder::newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int)> formatter) {
|
||||
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
|
||||
auto *opr = argArchive.back().get();
|
||||
opr->repr = formatter;
|
||||
opr->idx = oprCounter++;
|
||||
return opr;
|
||||
}
|
||||
|
||||
void PTXBuilder::initOperand(Operand *opr) {
|
||||
auto numBits = 0;
|
||||
// Derive numBits from the constraint.
|
||||
if (opr->constraint[1] == 'c' || opr->constraint[1] == 'h')
|
||||
numBits = 16;
|
||||
else if (opr->constraint[1] == 'r')
|
||||
numBits = 32;
|
||||
else if (opr->constraint[1] == 'l')
|
||||
numBits = 64;
|
||||
else
|
||||
llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str());
|
||||
// If numBits is less than 16, we use 16 as default because PTX does not
|
||||
// support 8-bit mov.
|
||||
numBits = numBits < 16 ? 16 : numBits;
|
||||
auto *zero = newConstantOperand(0);
|
||||
auto &init = create<>("mov")->o("u" + std::to_string(numBits));
|
||||
init(opr, zero);
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) {
|
||||
// Constraint should be something like "=r"
|
||||
assert(constraint.size() == 2 && constraint[0] == '=');
|
||||
auto *opr = newOperand();
|
||||
opr->idx = oprCounter++;
|
||||
opr->constraint = constraint;
|
||||
if (init) {
|
||||
initOperand(opr);
|
||||
}
|
||||
return opr;
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newOperand(unsigned operandIndex) {
|
||||
assert(operandIndex < oprCounter && "operand index out of range");
|
||||
auto *opr = newOperand();
|
||||
opr->idx = oprCounter++;
|
||||
opr->constraint = std::to_string(operandIndex);
|
||||
return opr;
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
argArchive.back()->repr = [v](int idx) { return v; };
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
|
||||
std::stringstream ss;
|
||||
ss << "0x" << std::hex << v;
|
||||
return newConstantOperand(ss.str());
|
||||
}
|
||||
|
||||
std::string PTXBuilder::getConstraints() const {
|
||||
auto args = getAllArgs();
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto arg : args)
|
||||
argReprs.push_back(arg->constraint);
|
||||
return strJoin(argReprs, ",");
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> PTXBuilder::getAllMLIRArgs() const {
|
||||
llvm::SmallVector<Value, 4> res;
|
||||
for (auto &arg : argArchive) {
|
||||
if (!arg->isList() && arg->value)
|
||||
res.push_back(arg->value);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<PTXBuilder::Operand *, 4> PTXBuilder::getAllArgs() const {
|
||||
llvm::SmallVector<Operand *, 4> res;
|
||||
for (auto &x : argArchive)
|
||||
if (!x->isList())
|
||||
res.push_back(x.get());
|
||||
return res;
|
||||
}
|
||||
|
||||
mlir::Value PTXBuilder::launch(OpBuilder &rewriter, Location loc, Type resTy,
|
||||
bool hasSideEffect, bool isAlignStack,
|
||||
ArrayRef<Attribute> attrs) const {
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||
loc, resTy, getAllMLIRArgs(), // operands
|
||||
dump(), // asm_string
|
||||
getConstraints(), // constraints
|
||||
hasSideEffect, // has_side_effects
|
||||
isAlignStack, // is_align_stack
|
||||
LLVM::AsmDialectAttr::get(ctx,
|
||||
LLVM::AsmDialect::AD_ATT), // asm_dialect
|
||||
ArrayAttr::get(ctx, attrs) // operand_attrs
|
||||
);
|
||||
|
||||
return inlineAsm.getRes();
|
||||
}
|
||||
|
||||
std::string PTXInstr::Operand::dump() const {
|
||||
if (repr)
|
||||
return repr(idx);
|
||||
if (!isList())
|
||||
return "$" + std::to_string(idx);
|
||||
|
||||
llvm::SmallVector<std::string> oprs;
|
||||
for (auto *opr : list)
|
||||
oprs.push_back(opr->dump());
|
||||
return "{ " + strJoin(oprs, ", ") + " }";
|
||||
}
|
||||
|
||||
PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
||||
StringRef constraint, int off) {
|
||||
auto *opr = newOperand(addr, constraint);
|
||||
opr->repr = [off](int idx) -> std::string {
|
||||
std::stringstream ss;
|
||||
ss << "[ $" << idx << " + " << off << " ]";
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
return opr;
|
||||
}
|
||||
|
||||
std::string PTXBuilder::dump() const {
|
||||
llvm::SmallVector<std::string> lines;
|
||||
for (auto &exec : executions) {
|
||||
lines.push_back(exec->dump());
|
||||
}
|
||||
|
||||
return strJoin(lines, "\n\t");
|
||||
}
|
||||
|
||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs) {
|
||||
if (onlyAttachMLIRArgs) {
|
||||
// Nearly impossible to make the $0,$1 in two PTX code snippets to point to
|
||||
// the same MLIR values in onlyAttachMLIRArgs mode.
|
||||
assert(builder->executions.empty() &&
|
||||
"builder can only hold a single execution when onlyAttachMIIRArgs "
|
||||
"is true.");
|
||||
builder->reorderArgArchive(oprs);
|
||||
}
|
||||
|
||||
builder->executions.emplace_back(
|
||||
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
|
||||
|
||||
return *builder->executions.back();
|
||||
}
|
||||
|
||||
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs) {
|
||||
return call(oprs, onlyAttachMLIRArgs);
|
||||
}
|
||||
|
||||
std::string PTXInstrExecution::dump() const {
|
||||
std::string osStr;
|
||||
llvm::raw_string_ostream os(osStr);
|
||||
|
||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||
if (onlyAttachMLIRArgs)
|
||||
return instrRepr;
|
||||
|
||||
if (pred) {
|
||||
if (!pred->repr)
|
||||
os << "@" << pred->dump() << " ";
|
||||
else
|
||||
os << pred->repr(pred->idx) << " ";
|
||||
}
|
||||
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto *arg : argsInOrder) {
|
||||
argReprs.push_back(arg->dump());
|
||||
}
|
||||
|
||||
std::string argsRepr = strJoin(argReprs, ", ");
|
||||
|
||||
os << instrRepr << " " << argsRepr << ";";
|
||||
os.flush();
|
||||
return osStr;
|
||||
}
|
||||
|
||||
SmallVector<PTXInstrExecution::Operand *>
|
||||
PTXInstrExecution::getArgList() const {
|
||||
SmallVector<Operand *> args;
|
||||
for (auto *arg : argsInOrder) {
|
||||
if (arg->isList())
|
||||
args.insert(args.end(), arg->list.begin(), arg->list.end());
|
||||
else
|
||||
args.push_back(arg);
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::global() {
|
||||
o("global");
|
||||
return *this;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::shared() {
|
||||
o("shared");
|
||||
return *this;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
|
||||
if (vecWidth > 1) {
|
||||
o("v" + std::to_string(vecWidth), predicate);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::b(int width) {
|
||||
o("b" + std::to_string(width));
|
||||
return *this;
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
748
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ReduceOpToLLVM.cpp
vendored
Normal file
748
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ReduceOpToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,748 @@
|
||||
#include "ReduceOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::LLVM::shflSync;
|
||||
using ::mlir::LLVM::storeShared;
|
||||
using ::mlir::triton::gpu_rocm::getOrder;
|
||||
using ::mlir::triton::gpu_rocm::getTotalElemsPerThread;
|
||||
|
||||
struct ReduceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
|
||||
public:
|
||||
ReduceOpConversion(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
int computeCapability, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp>(
|
||||
typeConverter, allocation, indexCacheInfo, benefit),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (ReduceOpHelper(op).isFastReduction())
|
||||
return matchAndRewriteFast(op, adaptor, rewriter);
|
||||
return matchAndRewriteBasic(op, adaptor, rewriter);
|
||||
}
|
||||
|
||||
private:
|
||||
int computeCapability;
|
||||
|
||||
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
|
||||
SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
|
||||
if (isFirst) {
|
||||
acc = SmallVector<Value>(cur.begin(), cur.end());
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a new copy of the reduce block, and inline it
|
||||
Block *currentBlock = rewriter.getBlock();
|
||||
Region &parent = *currentBlock->getParent();
|
||||
rewriter.cloneRegionBefore(combineOp, &parent.front());
|
||||
auto &newReduce = parent.front();
|
||||
auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
|
||||
|
||||
llvm::SmallVector<Value> combineArgs(2 * acc.size());
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
combineArgs[i] = acc[i];
|
||||
combineArgs[acc.size() + i] = cur[i];
|
||||
}
|
||||
|
||||
rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
|
||||
combineArgs);
|
||||
|
||||
auto results = returnOp.getResult();
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
acc[i] = results[i];
|
||||
}
|
||||
|
||||
// Delete the terminator, which is no longer used
|
||||
rewriter.eraseOp(returnOp);
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<Value>>
|
||||
unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto types = op.getInputTypes();
|
||||
auto operands = adaptor.getOperands();
|
||||
unsigned srcElems = getTotalElemsPerThread(types[0]);
|
||||
SmallVector<SmallVector<Value>> srcValues(srcElems);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto values = getTypeConverter()->unpackLLElements(loc, operands[i],
|
||||
rewriter, types[i]);
|
||||
|
||||
assert(values.size() == srcValues.size());
|
||||
for (unsigned j = 0; j < srcValues.size(); ++j) {
|
||||
srcValues[j].push_back(values[j]);
|
||||
}
|
||||
}
|
||||
return srcValues;
|
||||
}
|
||||
|
||||
// Calculates the write index in the shared memory where we would be writing
|
||||
// the within-thread accumulations before we start doing across-threads
|
||||
// accumulations. `index` is the index of the within-thread accumulations in
|
||||
// the full tensor, whereas `writeIdx` is the mapped-to index in the shared
|
||||
// memory
|
||||
void getWriteIndexBasic(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Attribute layout, SmallVector<Value> &index,
|
||||
SmallVector<Value> &writeIdx,
|
||||
std::map<int, Value> &ints, unsigned originalAxis,
|
||||
unsigned axis) const {
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
// Recover the axis in the parent layout
|
||||
auto parentAxis = axis < sliceLayout.getDim() ? axis : axis + 1;
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints,
|
||||
originalAxis, parentAxis);
|
||||
return;
|
||||
}
|
||||
|
||||
writeIdx = index;
|
||||
auto sizePerThread = triton::gpu_rocm::getSizePerThread(layout);
|
||||
Value axisSizePerThread = ints[sizePerThread[axis]];
|
||||
Value _8 = ints[8];
|
||||
Value _16 = ints[16];
|
||||
#if 1
|
||||
Value _2 = ints[2];
|
||||
Value _4 = ints[4];
|
||||
Value _32 = ints[32];
|
||||
#endif
|
||||
|
||||
if (layout.isa<BlockedEncodingAttr>()) {
|
||||
// A single thread owns axisSizePerThread contiguous values
|
||||
// on the reduction axis. After within thread reduction,
|
||||
// we would have a single accumulation every `axisSizePerThread`
|
||||
// contiguous values in the original tensor, so we would need
|
||||
// to map every `axisSizePerThread` to 1 value in smem as:
|
||||
// writeIdx[originalAxis] = index[originalAxis] / axisSizePerThread
|
||||
writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (!mmaLayout.isAmpere() && !mmaLayout.isHopper()) {
|
||||
llvm::report_fatal_error("Unsupported layout");
|
||||
}
|
||||
if (originalAxis == 0) {
|
||||
// Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8
|
||||
// rows in smem would correspond to a warp. The mapping
|
||||
// is: (warp_index) x 8 + (row index within warp)
|
||||
writeIdx[originalAxis] = add(mul(udiv(index[originalAxis], _16), _8),
|
||||
urem(index[originalAxis], _8));
|
||||
} else {
|
||||
// Same as BlockedEncodingAttr case
|
||||
writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread);
|
||||
}
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
// TODO: Support MFMA transposed layout.
|
||||
if (axis == 0) {
|
||||
// Because warpTileSize = [32, 32] and threadsPerWarp = [2, 32], each 2
|
||||
// rows in smem would correspond to a warp. The mapping
|
||||
// is: (warp_index) x 2 + (row index within warp)
|
||||
writeIdx[axis] = add(mul(udiv(index[axis], _32), _2),
|
||||
udiv(urem(index[axis], _32), _4));
|
||||
} else {
|
||||
// Same as BlockedEncodingAttr case
|
||||
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
|
||||
}
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported layout");
|
||||
}
|
||||
}
|
||||
|
||||
// Use shared memory for reduction within warps and across warps
|
||||
LogicalResult
|
||||
matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ReduceOpHelper helper(op);
|
||||
Location loc = op.getLoc();
|
||||
unsigned axis = op.getAxis();
|
||||
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
if (!helper.isSupportedLayout()) {
|
||||
assert(false && "Unexpected srcLayout in ReduceOpConversion");
|
||||
}
|
||||
// The order of the axes for the the threads within the warp
|
||||
auto srcOrd = triton::gpu_rocm::getOrder(srcLayout);
|
||||
auto sizePerThread = triton::gpu_rocm::getSizePerThread(srcLayout);
|
||||
auto srcShape = helper.getSrcShape();
|
||||
|
||||
SmallVector<Type> elemPtrTys(srcTys.size());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto ty = srcTys[i].getElementType();
|
||||
auto llvmElemTy = getTypeConverter()->convertType(ty);
|
||||
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
}
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
|
||||
auto smemShape = helper.getScratchConfigBasic();
|
||||
unsigned elems = product<unsigned>(smemShape);
|
||||
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] =
|
||||
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)),
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
|
||||
reduceWithinThreads(helper, srcValues, accs, indices, rewriter);
|
||||
|
||||
// cached int32 constants
|
||||
std::map<int, Value> ints;
|
||||
ints[0] = i32_val(0);
|
||||
for (int N = smemShape[axis] / 2; N > 0; N >>= 1)
|
||||
ints[N] = i32_val(N);
|
||||
ints[sizePerThread[axis]] = i32_val(sizePerThread[axis]);
|
||||
ints[8] = i32_val(8);
|
||||
ints[16] = i32_val(16);
|
||||
#if 1
|
||||
ints[2] = i32_val(2);
|
||||
ints[4] = i32_val(4);
|
||||
ints[32] = i32_val(32);
|
||||
#endif
|
||||
// reduce across threads
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
auto &acc = it.second;
|
||||
// get the writeIdx at which to write in smem
|
||||
SmallVector<Value> writeIdx;
|
||||
getWriteIndexBasic(rewriter, loc, srcLayout, indices[key], writeIdx, ints,
|
||||
axis, axis);
|
||||
|
||||
// calculate the offset in smem for that writeIdx
|
||||
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
|
||||
SmallVector<Value> writePtrs(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
// Store the within-thread accumulated value into shared memory
|
||||
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
|
||||
store(acc[i], writePtrs[i]);
|
||||
}
|
||||
|
||||
SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
|
||||
// Perform parallel reduction with sequential addressing
|
||||
// E.g. We reduce `smemShape[axis]` elements into `smemShape[axis]/2`
|
||||
// elements using `smemShape[axis]/2` threads where each thread
|
||||
// would accumalte values that are `smemShape[axis]/2` apart
|
||||
// to avoid bank conflicts. Then we repeat with `smemShape[axis]/4`
|
||||
// threads, .. etc.
|
||||
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
|
||||
// The readIdx will be N elements away on the reduction axis
|
||||
readIdx[axis] = ints[N];
|
||||
// If the writeIdx is greater or equal to N, do nothing
|
||||
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
|
||||
// Calculate the readOffset, if readMask is False, readOffset=0
|
||||
// meaning we reduce the value at writeIdx with itself
|
||||
Value readOffset = select(
|
||||
readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
|
||||
ints[0]);
|
||||
SmallVector<Value> readPtrs(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
// The readPtr is readOffset away from writePtr
|
||||
readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset);
|
||||
}
|
||||
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// Combine accumulator value from another thread
|
||||
SmallVector<Value> cur(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
cur[i] = load(readPtrs[i]);
|
||||
}
|
||||
accumulate(rewriter, op.getCombineOp(), acc, cur, false);
|
||||
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// Publish our new accumulator value to shared memory
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
store(acc[i], writePtrs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// set output values
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
|
||||
auto resultLayout = resultTy.getEncoding();
|
||||
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (unsigned j = 0; j < resultElems; ++j) {
|
||||
SmallVector<Value> readIdx = resultIndices[j];
|
||||
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, readIdx, smemShape, srcOrd);
|
||||
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
|
||||
resultVals[j] = load(readPtr);
|
||||
}
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else {
|
||||
// 0d-tensor -> scalar
|
||||
results[i] = load(smemBases[i]);
|
||||
}
|
||||
}
|
||||
|
||||
auto parentBlock = op.getOperation()->getBlock();
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
|
||||
void sync(ConversionPatternRewriter &rewriter, Location loc,
|
||||
triton::ReduceOp op) const {
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
|
||||
// attr.
|
||||
if (getWSAgentId(op)) {
|
||||
barSync(rewriter, op, getAgentIds(op).front(), 128);
|
||||
} else {
|
||||
barrier();
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the reduction can use a redux op and return the kind.
|
||||
std::optional<NVVM::ReduxKind> matchReduxKind(triton::ReduceOp op) const {
|
||||
#if 1
|
||||
return std::nullopt;
|
||||
#endif
|
||||
if (computeCapability < 80)
|
||||
return std::nullopt;
|
||||
if (op.getNumOperands() != 1 || op.getNumResults() != 1)
|
||||
return std::nullopt;
|
||||
Block *block = &(*op.getCombineOp().begin());
|
||||
Operation *yield = block->getTerminator();
|
||||
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
|
||||
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
|
||||
reduceOp->getNumResults() != 1)
|
||||
return std::nullopt;
|
||||
auto intType = reduceOp->getResultTypes()[0].dyn_cast<IntegerType>();
|
||||
if (!intType || intType.getWidth() > 32)
|
||||
return std::nullopt;
|
||||
if (reduceOp->getOperand(0) != block->getArgument(0) ||
|
||||
reduceOp->getOperand(1) != block->getArgument(1))
|
||||
return std::nullopt;
|
||||
if (isa<arith::AddIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::ADD;
|
||||
if (isa<arith::AndIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::AND;
|
||||
if (isa<arith::OrIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::OR;
|
||||
if (isa<arith::XOrIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::XOR;
|
||||
if (isa<arith::MinSIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::MIN;
|
||||
if (isa<arith::MinUIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::UMIN;
|
||||
if (isa<arith::MaxSIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::MAX;
|
||||
if (isa<arith::MaxUIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::UMAX;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Reduce along op axis for elements that are in the same thread. The
|
||||
// accumulated value is stored in accs.
|
||||
void reduceWithinThreads(
|
||||
ReduceOpHelper &helper, SmallVector<SmallVector<Value>> &srcValues,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
RankedTensorType operandType = op.getInputTypes()[0];
|
||||
// Assumes offsets don't actually depend on type
|
||||
SmallVector<SmallVector<unsigned>> offset =
|
||||
emitOffsetForLayout(helper.getSrcLayout(), operandType);
|
||||
unsigned srcElems = getTotalElemsPerThread(operandType);
|
||||
auto *combineOp = &op.getCombineOp();
|
||||
auto srcIndices =
|
||||
emitIndices(op.getLoc(), rewriter, helper.getSrcLayout(), operandType);
|
||||
// reduce within threads
|
||||
for (unsigned i = 0; i < srcElems; ++i) {
|
||||
SmallVector<unsigned> key = offset[i];
|
||||
key[op.getAxis()] = 0;
|
||||
bool isFirst = accs.find(key) == accs.end();
|
||||
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
|
||||
if (isFirst)
|
||||
indices[key] = srcIndices[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Apply warp reduction across the given number of contiguous lanes using op
|
||||
// region and the accumulator values as source.
|
||||
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
|
||||
SmallVector<Value> &acc, triton::ReduceOp op,
|
||||
unsigned numLaneToReduce) const {
|
||||
if (auto kind = matchReduxKind(op)) {
|
||||
// Based on benchmarking on A100 redux op gives a speed up only when doing
|
||||
// a single reduction (not partioned) and when the mask is static.
|
||||
// Therefore we currently only enable it to reduce across all the lanes.
|
||||
if (numLaneToReduce == 32) {
|
||||
assert(acc.size() == 1);
|
||||
Value mask = i32_val(0xFFFFFFFF);
|
||||
// Even though we currently don't use redux for partitioned reduction
|
||||
// the code below supports it in case we want to tweak the heuristic.
|
||||
if (numLaneToReduce < 32) {
|
||||
// For partitioned reduction we need to caluclate the mask so that
|
||||
// each group of numLaneToReduce threads has the correct mask.
|
||||
unsigned bitmask = (1 << numLaneToReduce) - 1;
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value laneId = urem(threadId, i32_val(32));
|
||||
mask = shl(i32_val(bitmask),
|
||||
and_(laneId, i32_val(~(numLaneToReduce - 1))));
|
||||
}
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
unsigned bitwidth = acc[i].getType().cast<IntegerType>().getWidth();
|
||||
if (bitwidth < 32) {
|
||||
if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX)
|
||||
acc[i] = sext(i32_ty, acc[i]);
|
||||
else
|
||||
acc[i] = zext(i32_ty, acc[i]);
|
||||
}
|
||||
acc[i] = rewriter.create<NVVM::ReduxOp>(loc, acc[i].getType(), acc[0],
|
||||
*kind, mask);
|
||||
if (bitwidth < 32)
|
||||
acc[i] = trunc(int_ty(bitwidth), acc[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
|
||||
SmallVector<Value> shfl(acc.size());
|
||||
unsigned shuffleIdx = N;
|
||||
#if 1
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto inputTy = srcTys[0].cast<RankedTensorType>();
|
||||
auto inMfma =
|
||||
inputTy.getEncoding().dyn_cast<triton::gpu_rocm::MfmaEncodingAttr>();
|
||||
if (inMfma && inMfma.getIsTransposed()) {
|
||||
assert(numLaneToReduce == 2 || numLaneToReduce == 4);
|
||||
// for mfma 32x32 adjecant threads in y dimension in transposed MFMA layout are 32
|
||||
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
|
||||
// for mfma 16x16 adjecant threads in y dimension in transposed MFMA layout are 16
|
||||
// apart: [[0 0 0 0 16 16 16 16 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
|
||||
const int warpSize = 64;
|
||||
shuffleIdx = warpSize / N / 2;
|
||||
}
|
||||
#endif
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx);
|
||||
}
|
||||
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce across threads within each warp.
|
||||
void
|
||||
reduceWithinWarps(ReduceOpHelper &helper,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
SmallVector<Value> &acc = accs[key];
|
||||
warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack the accumualtor values and replace the reduce op with the result.
|
||||
void packResults(ReduceOpHelper &helper,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
Location loc = op.getLoc();
|
||||
unsigned axis = op.getAxis();
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
SmallVector<SmallVector<unsigned>> resultOffset =
|
||||
emitOffsetForLayout(resultLayout, resultTy);
|
||||
SmallVector<Value> resultVals;
|
||||
for (int j = 0; j < resultElems; j++) {
|
||||
auto key = resultOffset[j];
|
||||
key.insert(key.begin() + axis, 0);
|
||||
resultVals.push_back(accs[key][i]);
|
||||
}
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else
|
||||
results[i] = accs.begin()->second[i];
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
}
|
||||
|
||||
// Return the type of the shared memory pointer for operand i.
|
||||
Type getElementPtrType(triton::ReduceOp op, int i) const {
|
||||
auto ty = op.getInputTypes()[i].getElementType();
|
||||
auto llvmElemTy = getTypeConverter()->convertType(ty);
|
||||
return LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
}
|
||||
|
||||
void storeWarpReduceToSharedMemory(
|
||||
ReduceOpHelper &helper,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
|
||||
SmallVector<Value> &smemBases,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
Location loc = op.getLoc();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
unsigned wavefront_size = triton::gpu_rocm::getWarpSize(srcLayout);
|
||||
Value warpSize = i32_val(wavefront_size);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
auto srcShape = helper.getSrcShape();
|
||||
unsigned axis = op.getAxis();
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
|
||||
auto threadsPerWarp =
|
||||
triton::gpu_rocm::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
|
||||
auto warpsPerCTA =
|
||||
triton::gpu_rocm::getWarpsPerCTAWithUniqueData(srcLayout, srcShape);
|
||||
auto order = getOrder(srcLayout);
|
||||
SmallVector<Value> multiDimLaneId =
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
|
||||
#if 1
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto inputTy = srcTys[0].cast<RankedTensorType>();
|
||||
auto inMfma =
|
||||
inputTy.getEncoding().dyn_cast<triton::gpu_rocm::MfmaEncodingAttr>();
|
||||
// Original logic is buggy for warpsPerCTA=[2, 2], but works fine for
|
||||
// warpsPerCTA=[4, 1] (that is used in flash attention, thus tested).
|
||||
// TODO: Check whether this is the case for MMA layout as well, if yes, this
|
||||
// should be fixed in the upstream repo.
|
||||
if (inMfma) {
|
||||
multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp);
|
||||
multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA);
|
||||
}
|
||||
#endif
|
||||
|
||||
Value laneIdAxis = multiDimLaneId[axis];
|
||||
Value warpIdAxis = multiDimWarpId[axis];
|
||||
|
||||
Value zero = i32_val(0);
|
||||
Value laneZero = icmp_eq(laneIdAxis, zero);
|
||||
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
SmallVector<Value> acc = it.second;
|
||||
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = warpIdAxis;
|
||||
Value writeOffset =
|
||||
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto elemPtrTy = getElementPtrType(op, i);
|
||||
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc[i], laneZero);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load the reduction of each warp and accumulate them to a final value and
|
||||
// store back to shared memory.
|
||||
void accumulatePartialReductions(ReduceOpHelper &helper,
|
||||
SmallVector<Value> &smemBases,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value zero = i32_val(0);
|
||||
|
||||
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
|
||||
unsigned numThreads =
|
||||
product<unsigned>(triton::gpu_rocm::getWarpsPerCTA(srcLayout)) *
|
||||
triton::gpu_rocm::TritonGPUROCMDialect::getThreadsPerWarp(mod);
|
||||
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
|
||||
Value readOffset = threadId;
|
||||
for (unsigned round = 0; round < elemsPerThread; ++round) {
|
||||
// FIXME(Qingyi): need predicate icmp_slt(threadId,
|
||||
// i32_val(sizeInerWarps))
|
||||
SmallVector<Value> acc(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto elemPtrTy = getElementPtrType(op, i);
|
||||
Value readPtr = gep(elemPtrTy, smemBases[i], readOffset);
|
||||
acc[i] = load(readPtr);
|
||||
}
|
||||
warpReduce(rewriter, loc, acc, op, sizeInterWarps);
|
||||
|
||||
// only the first thread in each sizeInterWarps is writing
|
||||
Value writeOffset = readOffset;
|
||||
SmallVector<Value> writePtrs(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto elemPtrTy = getElementPtrType(op, i);
|
||||
writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset);
|
||||
}
|
||||
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
|
||||
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
|
||||
Value laneIdModSizeInterWarpsIsZero =
|
||||
icmp_eq(laneIdModSizeInterWarps, zero);
|
||||
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
|
||||
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
unsigned wavefront_size = triton::gpu_rocm::getWarpSize(srcLayout);
|
||||
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
#if 1
|
||||
// This barrier is known to be critical for Navi 2x/3x
|
||||
if (i > 0 && wavefront_size == 32) {
|
||||
GCNBuilder BuilderMemfenceLDS;
|
||||
BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()();
|
||||
BuilderMemfenceLDS.launch(rewriter, loc, void_ty(rewriter.getContext()));
|
||||
}
|
||||
#endif
|
||||
storeShared(rewriter, loc, writePtrs[i], acc[i], pred);
|
||||
}
|
||||
|
||||
if (round != elemsPerThread - 1) {
|
||||
readOffset = add(readOffset, i32_val(numThreads));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load the final reduction from shared memory and replace the reduce result
|
||||
// with it.
|
||||
void loadReductionAndPackResult(ReduceOpHelper &helper,
|
||||
SmallVector<Value> &smemBases,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
Location loc = op.getLoc();
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
auto order = getOrder(helper.getSrcLayout());
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (size_t j = 0; j < resultElems; ++j) {
|
||||
SmallVector<Value> readIdx = resultIndices[j];
|
||||
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, readIdx, smemShapes[0], order);
|
||||
Value readPtr =
|
||||
gep(getElementPtrType(op, i), smemBases[i], readOffset);
|
||||
resultVals[j] = load(readPtr);
|
||||
}
|
||||
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else {
|
||||
// 0d-tensor -> scalar
|
||||
results[i] = load(smemBases[i]);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
}
|
||||
|
||||
// Use warp shuffle for reduction within warps and shared memory for data
|
||||
// exchange across warps
|
||||
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ReduceOpHelper helper(op);
|
||||
assert(helper.isSupportedLayout() &&
|
||||
"Unexpected srcLayout in ReduceOpConversion");
|
||||
Location loc = op->getLoc();
|
||||
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
|
||||
// First reduce all the values along axis within each thread.
|
||||
reduceWithinThreads(helper, srcValues, accs, indices, rewriter);
|
||||
|
||||
// Then reduce across threads within a warp.
|
||||
reduceWithinWarps(helper, accs, rewriter);
|
||||
|
||||
if (helper.isWarpSynchronous()) {
|
||||
// If all the values to be reduced are within the same warp there is
|
||||
// nothing left to do.
|
||||
packResults(helper, accs, rewriter);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Compute a shared memory base per operand.
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
smemBases[0] =
|
||||
bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()),
|
||||
getElementPtrType(op, 0));
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] = bitcast(gep(getElementPtrType(op, i - 1), smemBases[i - 1],
|
||||
i32_val(maxElems)),
|
||||
getElementPtrType(op, i));
|
||||
}
|
||||
storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter);
|
||||
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// The second round of shuffle reduction
|
||||
// now the problem size: sizeInterWarps, s1, s2, .. , sn
|
||||
// where sizeInterWarps is 2^m
|
||||
//
|
||||
// Each thread needs to process:
|
||||
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
|
||||
accumulatePartialReductions(helper, smemBases, rewriter);
|
||||
|
||||
// We could avoid this barrier in some of the layouts, however this is not
|
||||
// the general case.
|
||||
// TODO: optimize the barrier in case the layouts are accepted.
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// set output values
|
||||
loadReductionAndPackResult(helper, smemBases, rewriter);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateReduceOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
int computeCapability, PatternBenefit benefit) {
|
||||
patterns.add<ReduceOpConversion>(typeConverter, allocation, indexCacheInfo,
|
||||
computeCapability, benefit);
|
||||
}
|
||||
16
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ReduceOpToLLVM.h
vendored
Normal file
16
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ReduceOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_REDUCE_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_REDUCE_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateReduceOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
int computeCapability, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
43
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/RegReallocOpToLLVM.cpp
vendored
Normal file
43
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/RegReallocOpToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
#include "RegReallocOpToLLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
struct RegAllocOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegAllocOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::RegAllocOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::RegAllocOp>(
|
||||
op, adaptor.getRegCount());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RegDeallocOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegDeallocOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::RegDeallocOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::RegDeallocOp>(
|
||||
op, adaptor.getRegCount());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateRegReallocOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
const ModuleAllocation &allocation, PatternBenefit benefit) {
|
||||
patterns.add<RegAllocOpConversion>(typeConverter, benefit);
|
||||
patterns.add<RegDeallocOpConversion>(typeConverter, benefit);
|
||||
return;
|
||||
}
|
||||
14
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/RegReallocOpToLLVM.h
vendored
Normal file
14
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/RegReallocOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_REGREALLOC_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_REGREALLOC_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateRegReallocOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
const ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
332
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ScanOpToLLVM.cpp
vendored
Normal file
332
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ScanOpToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,332 @@
|
||||
#include "ScanOpToLLVM.h"
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
#include "triton/AnalysisROCM/Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::LLVM::shflUpSync;
|
||||
using ::mlir::LLVM::storeShared;
|
||||
|
||||
// Apply the region of the scan op to the acc and cur values and update acc
|
||||
// inplace with the result.
|
||||
static void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
|
||||
Value &acc, Value cur) {
|
||||
if (!acc) {
|
||||
acc = cur;
|
||||
return;
|
||||
}
|
||||
// Create a new copy of the reduce block, and inline it
|
||||
Block *currentBlock = rewriter.getBlock();
|
||||
Region &parent = *currentBlock->getParent();
|
||||
rewriter.cloneRegionBefore(combineOp, &parent.front());
|
||||
auto &newScan = parent.front();
|
||||
auto returnOp = dyn_cast<triton::ScanReturnOp>(newScan.getTerminator());
|
||||
llvm::SmallVector<Value> combineArgs = {acc, cur};
|
||||
rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(),
|
||||
combineArgs);
|
||||
auto results = returnOp.getResult();
|
||||
acc = results[0];
|
||||
// Delete the terminator, which is no longer used
|
||||
rewriter.eraseOp(returnOp);
|
||||
}
|
||||
|
||||
// Scan a contiguous elements within a thread and update `srcValues` in place.
|
||||
static void scanThreadContiguousElements(SmallVector<Value> &srcValues,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
ScanLoweringHelper &helper) {
|
||||
// Depending on layout contiguous elements along axis dim may not be
|
||||
// contiguous in srcValues. Keep track of what elements belong to the same
|
||||
// chunk of contiguous elements.
|
||||
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
|
||||
unsigned parallelElementsPerThread = helper.getAxisNumElementsPerThread();
|
||||
unsigned numChunks = srcValues.size() / scanElementsPerThreads;
|
||||
unsigned stride = helper.getAxisElementStride();
|
||||
SmallVector<Value> accs(numChunks);
|
||||
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
|
||||
unsigned accIndex = (srcIndex % stride) +
|
||||
((srcIndex / stride) / scanElementsPerThreads) * stride;
|
||||
|
||||
accumulate(rewriter, helper.getCombineOp(), accs[accIndex],
|
||||
srcValues[srcIndex]);
|
||||
srcValues[srcIndex] = accs[accIndex];
|
||||
}
|
||||
}
|
||||
|
||||
// Apply a scan across threads of the warp for the last element of each
|
||||
// contiguous group of elements.
|
||||
static void warpScan(SmallVector<Value> &srcValues,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
ScanLoweringHelper &helper, Value laneIdAxis,
|
||||
Value laneId) {
|
||||
Location loc = helper.getLoc();
|
||||
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
|
||||
unsigned elementStride = helper.getAxisElementStride();
|
||||
unsigned threadStride = helper.getAxisThreadStride();
|
||||
unsigned scanDim = helper.getAxisNumThreadsPerWarp();
|
||||
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
|
||||
unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
|
||||
// Only consider the last element of each contiguous chunk of elements.
|
||||
if (elementIdx != scanElementsPerThreads - 1)
|
||||
continue;
|
||||
// Reduce within warps.
|
||||
Value acc = srcValues[srcIndex];
|
||||
for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) {
|
||||
Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride, laneId);
|
||||
Value tempAcc = acc;
|
||||
accumulate(rewriter, helper.getCombineOp(), tempAcc, shfl);
|
||||
Value mask = icmp_slt(laneIdAxis, i32_val(i));
|
||||
acc = select(mask, acc, tempAcc);
|
||||
}
|
||||
srcValues[srcIndex] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
// For each set of contiguous elements within a thread we store the partial
|
||||
// reduction into shared memory. Each parallel scan and each warp will store its
|
||||
// own partial reductions. The shared memory is organized as follow:
|
||||
// -----------------------------------------------------------------
|
||||
// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 |
|
||||
// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 |
|
||||
static void storeWarpAccumulator(SmallVector<Value> &srcValues,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
ScanLoweringHelper &helper, Value laneId,
|
||||
Value warpId, Value baseSharedMemPtr,
|
||||
Value parallelLaneId) {
|
||||
Location loc = helper.getLoc();
|
||||
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
|
||||
unsigned scanDim = helper.getAxisNumThreadsPerWarp();
|
||||
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
|
||||
unsigned numWarps = helper.getAxisNumWarps();
|
||||
unsigned chunkId = 0;
|
||||
unsigned elementStride = helper.getAxisElementStride();
|
||||
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
|
||||
unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
|
||||
// Only consider the last element of each contiguous chunk of elements.
|
||||
if (elementIdx != scanElementsPerThreads - 1)
|
||||
continue;
|
||||
Value lastElement = srcValues[srcIndex];
|
||||
Value mask = icmp_eq(laneId, i32_val(scanDim - 1));
|
||||
Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane)));
|
||||
index = add(index, i32_val(chunkId * numParallelLane * numWarps));
|
||||
Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index);
|
||||
storeShared(rewriter, loc, writePtr, lastElement, mask);
|
||||
chunkId++;
|
||||
}
|
||||
}
|
||||
|
||||
// Read the partial reductions from shared memory from each chunk of contiguous
|
||||
// elements for each warp and parallel scan. Then combine the partial reduction
|
||||
// with the right elements. Within a given contiguous element chunk we update
|
||||
// all the elements by accumulating the value from the last element of the
|
||||
// reduced value from the previous lane.
|
||||
static void AddPartialReduce(SmallVector<Value> &srcValues,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
ScanLoweringHelper &helper, Value sharedMemoryPtr,
|
||||
Value warpId, Value laneIdAxis,
|
||||
Value parallelLaneId, Value laneId) {
|
||||
Location loc = helper.getLoc();
|
||||
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
|
||||
unsigned numWarps = helper.getAxisNumWarps();
|
||||
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
|
||||
unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
|
||||
unsigned elementStride = helper.getAxisElementStride();
|
||||
unsigned threadStride = helper.getAxisThreadStride();
|
||||
Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
|
||||
Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
|
||||
Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
|
||||
struct Accumulator {
|
||||
Value acc;
|
||||
Value maskedAcc;
|
||||
};
|
||||
unsigned numScanBlocks = helper.getAxisNumBlocks();
|
||||
unsigned numParallelBlocks = helper.getNonAxisNumBlocks();
|
||||
assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread *
|
||||
scanElementsPerThreads ==
|
||||
srcValues.size());
|
||||
SmallVector<Accumulator> accumulators(numParallelBlocks *
|
||||
parallelElementsPerThread);
|
||||
unsigned chunkId = 0;
|
||||
unsigned blockStride = helper.getAxisBlockStride();
|
||||
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
|
||||
unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
|
||||
// Only consider the last element of each contiguous chunk of elements.
|
||||
if (elementIdx != scanElementsPerThreads - 1)
|
||||
continue;
|
||||
// Accumulate the partial reduction from shared memory. Decide which
|
||||
// accumulator to combine based on whether the elements belong to the same
|
||||
// dimension along axis.
|
||||
unsigned blockId = chunkId / parallelElementsPerThread;
|
||||
unsigned parallelBlockId =
|
||||
blockId % blockStride +
|
||||
((blockId / blockStride) / numScanBlocks) * blockStride;
|
||||
unsigned accumulatorIndex = chunkId % parallelElementsPerThread +
|
||||
parallelBlockId * parallelElementsPerThread;
|
||||
Accumulator &accumulator = accumulators[accumulatorIndex];
|
||||
for (unsigned i = 0; i < numWarps; ++i) {
|
||||
Value index = add(parallelLaneId,
|
||||
i32_val(numParallelLane * (i + chunkId * numWarps)));
|
||||
Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index);
|
||||
Value partialReduce = load(ptr);
|
||||
if (!accumulator.acc) {
|
||||
accumulator.acc = partialReduce;
|
||||
accumulator.maskedAcc = partialReduce;
|
||||
continue;
|
||||
}
|
||||
accumulate(rewriter, helper.getCombineOp(), accumulator.acc,
|
||||
partialReduce);
|
||||
Value mask = icmp_slt(warpId, i32_val(i + 1));
|
||||
accumulator.maskedAcc =
|
||||
select(mask, accumulator.maskedAcc, accumulator.acc);
|
||||
}
|
||||
Value temp = srcValues[srcIndex];
|
||||
accumulate(rewriter, helper.getCombineOp(), temp, accumulator.maskedAcc);
|
||||
unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
|
||||
if (axisBlockId == 0) {
|
||||
// For the first warp and first chunk we don't have anything to
|
||||
// accumulate.
|
||||
temp = select(maskFirstWarp, srcValues[srcIndex], temp);
|
||||
}
|
||||
srcValues[srcIndex] = temp;
|
||||
// Update the rest of the contiguous elements.
|
||||
Value lastElement =
|
||||
shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride, laneId);
|
||||
lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement);
|
||||
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
|
||||
Value laneValue = srcValues[srcIndex - i * elementStride];
|
||||
accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement);
|
||||
if (axisBlockId == 0) {
|
||||
// For the first warp and first chunk we don't have anything to
|
||||
// accumulate.
|
||||
laneValue = select(maskFirstThread,
|
||||
srcValues[srcIndex - i * elementStride], laneValue);
|
||||
}
|
||||
srcValues[srcIndex - i * elementStride] = laneValue;
|
||||
}
|
||||
// For the next chunk start back from the value containing the
|
||||
// accumulated value of all the warps.
|
||||
accumulator.maskedAcc = accumulator.acc;
|
||||
chunkId++;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ScanOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::ScanOp> {
|
||||
public:
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::ScanOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (succeeded(emitFastScan(op, adaptor, rewriter)))
|
||||
return success();
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
std::tuple<Value, Value, Value>
|
||||
getDelinearizedIds(ConversionPatternRewriter &rewriter,
|
||||
ScanLoweringHelper &helper, Value laneId,
|
||||
Value warpId) const;
|
||||
LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
// Break up the threadId into lane and warp id along the scan dimension and
|
||||
// compute a flat id for the parallel dimensions.
|
||||
std::tuple<Value, Value, Value>
|
||||
ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
|
||||
ScanLoweringHelper &helper, Value laneId,
|
||||
Value warpId) const {
|
||||
auto loc = helper.getLoc();
|
||||
unsigned axis = helper.getAxis();
|
||||
auto srcEncoding = helper.getEncoding();
|
||||
|
||||
auto threadsPerWarp = triton::gpu_rocm::getThreadsPerWarp(srcEncoding);
|
||||
auto warpsPerCTA = triton::gpu_rocm::getWarpsPerCTA(srcEncoding);
|
||||
auto order = triton::gpu_rocm::getOrder(srcEncoding);
|
||||
SmallVector<Value> multiDimLaneId =
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
|
||||
Value laneIdAxis = multiDimLaneId[axis];
|
||||
Value warpIdAxis = multiDimWarpId[axis];
|
||||
|
||||
multiDimLaneId[axis] = i32_val(0);
|
||||
threadsPerWarp[axis] = 1;
|
||||
Value laneIdParallel =
|
||||
linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order);
|
||||
multiDimWarpId[axis] = i32_val(0);
|
||||
warpsPerCTA[axis] = 1;
|
||||
Value warpIdParallel =
|
||||
linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, order);
|
||||
Value flatIdParallel =
|
||||
add(laneIdParallel,
|
||||
mul(warpIdParallel, i32_val(helper.getNonAxisNumThreadsPerWarp())));
|
||||
return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel);
|
||||
}
|
||||
|
||||
// Lowering using warp shuffle operations to do warp level scan.
|
||||
LogicalResult
|
||||
ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ScanLoweringHelper helper(op);
|
||||
auto loc = helper.getLoc();
|
||||
if (!helper.isSupported())
|
||||
return failure();
|
||||
|
||||
// Obtain global laneId and pass it around
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
unsigned iWarpSize = triton::gpu_rocm::TritonGPUROCMDialect::getThreadsPerWarp(mod);
|
||||
Value warpSize = i32_val(iWarpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
|
||||
auto [laneIdAxis, warpIdAxis, flatIdParallel] =
|
||||
getDelinearizedIds(rewriter, helper, laneId, warpId);
|
||||
auto input = adaptor.getOperands()[0];
|
||||
auto type = op.getOperand(0).getType().cast<RankedTensorType>();
|
||||
SmallVector<Value> srcValues =
|
||||
getTypeConverter()->unpackLLElements(loc, input, rewriter, type);
|
||||
|
||||
// Scan contigous elements in a thread and update `srcValues`.
|
||||
scanThreadContiguousElements(srcValues, rewriter, helper);
|
||||
// Apply warp level scan to the last element of each chunk of contiguous
|
||||
// elements.
|
||||
warpScan(srcValues, rewriter, helper, laneIdAxis, laneId);
|
||||
|
||||
// Store the partial reducing for each warp into shared memory.
|
||||
Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3);
|
||||
Value baseSharedMemPtr = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys);
|
||||
storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis,
|
||||
baseSharedMemPtr, flatIdParallel);
|
||||
barrier();
|
||||
// Read back the partial reduction of each warp and accumulate them based on
|
||||
// warpId. Then update each chunk of contiguous elements by adding the
|
||||
// accumulated value from the previous lane.
|
||||
AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis,
|
||||
laneIdAxis, flatIdParallel, laneId);
|
||||
|
||||
Value results = getTypeConverter()->packLLElements(loc, srcValues, rewriter,
|
||||
input.getType());
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void populateScanOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<ScanOpConversion>(typeConverter, allocation, indexCacheInfo,
|
||||
benefit);
|
||||
}
|
||||
16
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ScanOpToLLVM.h
vendored
Normal file
16
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ScanOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_SCAN_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_SCAN_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateScanOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
104
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TensorPtrOpsToLLVM.cpp
vendored
Normal file
104
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TensorPtrOpsToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "TensorPtrOpsToLLVM.h"
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
struct MakeTensorPtrOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeTensorPtrOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::MakeTensorPtrOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
// struct { offset0, offset1, shape0, shape1, stride0,
|
||||
// stride1, base_ptr};
|
||||
auto offsets = adaptor.getOffsets();
|
||||
auto shapes = adaptor.getShape();
|
||||
auto strides = adaptor.getStrides();
|
||||
auto base = adaptor.getBase();
|
||||
auto result = op.getResult();
|
||||
|
||||
SmallVector<Value> elems;
|
||||
for (auto offset : offsets)
|
||||
elems.push_back(offset);
|
||||
for (auto shape : shapes)
|
||||
elems.push_back(shape);
|
||||
for (auto stride : strides)
|
||||
elems.push_back(stride);
|
||||
|
||||
elems.push_back(base);
|
||||
|
||||
auto newValue = getTypeConverter()->packLLElements(
|
||||
op.getLoc(), elems, rewriter, result.getType());
|
||||
rewriter.replaceOp(op, newValue);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AdvanceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AdvanceOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::AdvanceOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// struct { offset0, offset1, shape0, shape1, stride0,
|
||||
// stride1, base_ptr};
|
||||
auto loc = op.getLoc();
|
||||
auto ptrType = op.getPtr().getType();
|
||||
auto tensorPtr = adaptor.getPtr();
|
||||
|
||||
auto offsets = adaptor.getOffsets();
|
||||
auto elems =
|
||||
getTypeConverter()->unpackLLElements(loc, tensorPtr, rewriter, ptrType);
|
||||
|
||||
SmallVector<Value, 2> newOffsets;
|
||||
|
||||
for (auto [offset, oldOffset] : llvm::zip_first(offsets, elems)) {
|
||||
newOffsets.push_back((add(offset, oldOffset)));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < newOffsets.size(); ++i) {
|
||||
elems[i] = newOffsets[i];
|
||||
}
|
||||
|
||||
auto newValue = getTypeConverter()->packLLElements(op.getLoc(), elems,
|
||||
rewriter, ptrType);
|
||||
rewriter.replaceOp(op, newValue);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTensorPtrOpsToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit) {
|
||||
patterns.add<MakeTensorPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AdvanceOpConversion>(typeConverter, benefit);
|
||||
return;
|
||||
}
|
||||
37
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TensorPtrOpsToLLVM.h
vendored
Normal file
37
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TensorPtrOpsToLLVM.h
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_TENSOR_PTR_OPS_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_TENSOR_PTR_OPS_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateTensorPtrOpsToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
825
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVM.cpp
vendored
Normal file
825
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,825 @@
|
||||
#include "TritonGPUToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getSRegValue;
|
||||
using ::mlir::triton::gpu_rocm::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
unsigned numArguments = op.getNumOperands();
|
||||
|
||||
// Currently, Triton kernel function always return nothing.
|
||||
// TODO(Superjomn) add support for non-inline device function
|
||||
if (numArguments > 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only kernel function with nothing returned is supported.");
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
|
||||
op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct BroadcastOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::BroadcastOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Following the order of indices in the legacy code, a broadcast of:
|
||||
// [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)]
|
||||
// =>
|
||||
// [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)]
|
||||
//
|
||||
// logically maps to a broadcast within a thread's scope:
|
||||
// [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1),
|
||||
// 1,spt(k+1)..spt(n-1)]
|
||||
// =>
|
||||
// [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)]
|
||||
//
|
||||
// regardless of the order of the layout
|
||||
//
|
||||
Location loc = op->getLoc();
|
||||
Value src = adaptor.getSrc();
|
||||
Value result = op.getResult();
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto resultTy = result.getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto resultLayout = resultTy.getEncoding();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned rank = srcTy.getRank();
|
||||
|
||||
assert(rank == resultTy.getRank());
|
||||
auto order = triton::gpu_rocm::getOrder(srcLayout);
|
||||
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
|
||||
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
|
||||
SmallVector<Value> srcVals =
|
||||
getTypeConverter()->unpackLLElements(loc, src, rewriter, srcTy);
|
||||
|
||||
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
|
||||
for (size_t i = 0; i < srcOffsets.size(); i++) {
|
||||
srcValues[srcOffsets[i]] = srcVals[i];
|
||||
}
|
||||
|
||||
SmallVector<Value> resultVals;
|
||||
for (size_t i = 0; i < resultOffsets.size(); i++) {
|
||||
auto offset = resultOffsets[i];
|
||||
for (size_t j = 0; j < srcShape.size(); j++)
|
||||
if (srcShape[j] == 1)
|
||||
offset[j] = 0;
|
||||
resultVals.push_back(srcValues.lookup(offset));
|
||||
}
|
||||
|
||||
Value resultStruct =
|
||||
getTypeConverter()->packLLElements(loc, resultVals, rewriter, resultTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct PrintOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::PrintOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
SmallVector<Value, 16> operands;
|
||||
for (size_t i = 0; i < op.getNumOperands(); i++) {
|
||||
auto sub_operands = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
|
||||
for (auto elem : sub_operands) {
|
||||
operands.push_back(elem);
|
||||
}
|
||||
}
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
os << op.getPrefix();
|
||||
if (!operands.empty()) {
|
||||
os << getFormatSubstr(operands[0]);
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < operands.size(); ++i) {
|
||||
os << ", " << getFormatSubstr(operands[i]);
|
||||
}
|
||||
#if 1
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr, operands,
|
||||
rewriter);
|
||||
#else
|
||||
llPrintf(formatStr, operands, rewriter);
|
||||
#endif
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
std::string getFormatSubstr(Value value) const {
|
||||
Type type = value.getType();
|
||||
if (type.isa<LLVM::LLVMPointerType>()) {
|
||||
return "%p";
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
|
||||
return "%f";
|
||||
} else if (type.isSignedInteger()) {
|
||||
if (type.getIntOrFloatBitWidth() == 64)
|
||||
return "%lli";
|
||||
else
|
||||
return "%i";
|
||||
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
|
||||
if (type.getIntOrFloatBitWidth() == 64)
|
||||
return "%llu";
|
||||
else
|
||||
return "%u";
|
||||
}
|
||||
assert(false && "not supported type");
|
||||
return "";
|
||||
}
|
||||
|
||||
// declare vprintf(i8*, i8*) as external function
|
||||
static LLVM::LLVMFuncOp
|
||||
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
StringRef funcName("vprintf");
|
||||
Operation *funcOp = moduleOp.lookupSymbol(funcName);
|
||||
if (funcOp)
|
||||
return cast<LLVM::LLVMFuncOp>(*funcOp);
|
||||
|
||||
auto *context = rewriter.getContext();
|
||||
|
||||
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
|
||||
ptr_ty(IntegerType::get(context, 8))};
|
||||
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
|
||||
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
|
||||
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
|
||||
funcType);
|
||||
}
|
||||
|
||||
// extend integer to int32, extend float to float64
|
||||
// this comes from vprintf alignment requirements.
|
||||
static std::pair<Type, Value>
|
||||
promoteValue(ConversionPatternRewriter &rewriter, Value value) {
|
||||
auto *context = rewriter.getContext();
|
||||
auto type = value.getType();
|
||||
Value newOp = value;
|
||||
Type newType = type;
|
||||
auto loc = UnknownLoc::get(context);
|
||||
|
||||
bool bUnsigned = type.isUnsignedInteger();
|
||||
if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
|
||||
if (bUnsigned) {
|
||||
newType = ui32_ty;
|
||||
newOp = zext(newType, value);
|
||||
} else {
|
||||
newType = i32_ty;
|
||||
newOp = sext(newType, value);
|
||||
}
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32()) {
|
||||
newType = f64_ty;
|
||||
newOp = fpext(newType, value);
|
||||
}
|
||||
|
||||
return {newType, newOp};
|
||||
}
|
||||
|
||||
static void llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
assert(!msg.empty() && "printf with empty string not support");
|
||||
Type int8Ptr = ptr_ty(i8_ty);
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
auto funcOp = getVprintfDeclaration(rewriter);
|
||||
auto loc = UnknownLoc::get(ctx);
|
||||
|
||||
Value one = i32_val(1);
|
||||
Value zero = i32_val(0);
|
||||
|
||||
llvm::SmallString<64> msgNewline(msg);
|
||||
msgNewline.push_back('\n');
|
||||
msgNewline.push_back('\0');
|
||||
Value prefixString =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
|
||||
Value bufferPtr = null(int8Ptr);
|
||||
|
||||
SmallVector<Value, 16> newArgs;
|
||||
if (args.size() >= 1) {
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto arg : args) {
|
||||
Type newType;
|
||||
Value newArg;
|
||||
std::tie(newType, newArg) = promoteValue(rewriter, arg);
|
||||
argTypes.push_back(newType);
|
||||
newArgs.push_back(newArg);
|
||||
}
|
||||
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes);
|
||||
auto allocated =
|
||||
rewriter.create<LLVM::AllocaOp>(loc, ptr_ty(structTy), one,
|
||||
/*alignment=*/0);
|
||||
|
||||
for (const auto &entry : llvm::enumerate(newArgs)) {
|
||||
auto index = i32_val(entry.index());
|
||||
auto fieldPtr = gep(ptr_ty(argTypes[entry.index()]), allocated,
|
||||
ArrayRef<Value>{zero, index});
|
||||
store(entry.value(), fieldPtr);
|
||||
}
|
||||
bufferPtr = bitcast(allocated, int8Ptr);
|
||||
}
|
||||
|
||||
SmallVector<Value> operands{prefixString, bufferPtr};
|
||||
call(funcOp, operands);
|
||||
}
|
||||
};
|
||||
|
||||
struct AssertOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto ctx = rewriter.getContext();
|
||||
auto elems = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getCondition(), rewriter, op.getCondition().getType());
|
||||
auto elemTy = elems[0].getType();
|
||||
Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0);
|
||||
for (auto elem : elems) {
|
||||
if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) {
|
||||
condition =
|
||||
or_(condition,
|
||||
icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
|
||||
loc, elemTy, rewriter.getZeroAttr(elemTy))));
|
||||
} else {
|
||||
assert(false && "Unsupported type for assert");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
llAssertHIP(op, condition, adaptor.getMessage(), adaptor.getFile(),
|
||||
adaptor.getFunc(), adaptor.getLine(), rewriter);
|
||||
#else
|
||||
llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
|
||||
adaptor.getFunc(), adaptor.getLine(), rewriter);
|
||||
#endif
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
// op: the op at which the assert is inserted. Unlike printf, we need to
|
||||
// know about the op to split the block.
|
||||
#if 1
|
||||
void llAssertHIP(Operation *op, Value condition, StringRef message,
|
||||
StringRef file, StringRef func, int line,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
auto ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// #prevBlock
|
||||
// if (condition) {
|
||||
// #ifBlock
|
||||
// print(message);
|
||||
// halt;
|
||||
// }
|
||||
// #endBlock
|
||||
Block *prevBlock = op->getBlock();
|
||||
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToStart(ifBlock);
|
||||
|
||||
SmallString<256> tmpBuf;
|
||||
message =
|
||||
llvm::Twine("Assertion failed: " + message + ", File: " + file +
|
||||
", Function: " + func + ", Line: " + llvm::Twine(line))
|
||||
.toStringRef(tmpBuf);
|
||||
|
||||
// Print assert message.
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), message,
|
||||
ValueRange(), rewriter, /*stderr*/ true);
|
||||
|
||||
// Perform the trap.
|
||||
GCNBuilder BuilderTrap;
|
||||
// TODO: LLVM::Trap LLVM::DebugTrap instructions don't work here.
|
||||
BuilderTrap.create<>("s_endpgm")->operator()();
|
||||
BuilderTrap.launch(rewriter, loc, void_ty(ctx));
|
||||
|
||||
// Split a block after the call.
|
||||
Block *endBlock = rewriter.splitBlock(ifBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToEnd(ifBlock);
|
||||
rewriter.create<cf::BranchOp>(loc, endBlock);
|
||||
rewriter.setInsertionPointToEnd(prevBlock);
|
||||
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, endBlock);
|
||||
}
|
||||
|
||||
#else // USE_ROCM
|
||||
|
||||
static void llAssert(Operation *op, Value condition, StringRef message,
|
||||
StringRef file, StringRef func, int line,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
auto ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// #block1
|
||||
// if (condition) {
|
||||
// #block2
|
||||
// __assertfail(message);
|
||||
// }
|
||||
// #block3
|
||||
Block *prevBlock = op->getBlock();
|
||||
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToStart(ifBlock);
|
||||
|
||||
auto funcOp = getAssertfailDeclaration(rewriter);
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
Value messageString =
|
||||
LLVM::addStringToModule(loc, rewriter, "assertMessage_", message);
|
||||
Value fileString =
|
||||
LLVM::addStringToModule(loc, rewriter, "assertFile_", file);
|
||||
Value funcString =
|
||||
LLVM::addStringToModule(loc, rewriter, "assertFunc_", func);
|
||||
Value lineNumber = i32_val(line);
|
||||
Value charSize = int_val(sizeof(size_t) * 8, sizeof(char));
|
||||
|
||||
SmallVector<Value> operands = {messageString, fileString, lineNumber,
|
||||
funcString, charSize};
|
||||
auto ret = call(funcOp, operands);
|
||||
|
||||
// Split a block after the call.
|
||||
Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToEnd(ifBlock);
|
||||
rewriter.create<cf::BranchOp>(loc, thenBlock);
|
||||
rewriter.setInsertionPointToEnd(prevBlock);
|
||||
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
|
||||
}
|
||||
|
||||
static LLVM::LLVMFuncOp
|
||||
getAssertfailDeclaration(ConversionPatternRewriter &rewriter) {
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
StringRef funcName("__assertfail");
|
||||
Operation *funcOp = moduleOp.lookupSymbol(funcName);
|
||||
if (funcOp)
|
||||
return cast<LLVM::LLVMFuncOp>(*funcOp);
|
||||
|
||||
// void __assert_fail(const char * assertion, const char * file, unsigned
|
||||
// int line, const char * function);
|
||||
auto *ctx = rewriter.getContext();
|
||||
SmallVector<Type> argsType{ptr_ty(i8_ty), ptr_ty(i8_ty), i32_ty,
|
||||
ptr_ty(i8_ty),
|
||||
rewriter.getIntegerType(sizeof(size_t) * 8)};
|
||||
auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType);
|
||||
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
|
||||
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), funcName,
|
||||
funcType);
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
};
|
||||
|
||||
struct MakeRangeOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
|
||||
|
||||
MakeRangeOpConversion(
|
||||
TritonGPUToLLVMTypeConverter &converter,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
|
||||
converter, indexCacheInfo, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto rankedTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto shape = rankedTy.getShape();
|
||||
auto layout = rankedTy.getEncoding();
|
||||
|
||||
auto elemTy = rankedTy.getElementType();
|
||||
assert(elemTy.isInteger(32));
|
||||
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart());
|
||||
auto idxs = emitIndices(loc, rewriter, layout, rankedTy);
|
||||
unsigned elems = idxs.size();
|
||||
SmallVector<Value> retVals(elems);
|
||||
// TODO: slice layout has more elements than expected.
|
||||
// Unexpected behavior for make range, but generally OK when followed by
|
||||
// expand dims + broadcast. very weird behavior otherwise potentially.
|
||||
for (const auto &multiDim : llvm::enumerate(idxs)) {
|
||||
assert(multiDim.value().size() == 1);
|
||||
retVals[multiDim.index()] = add(multiDim.value()[0], start);
|
||||
}
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, retVals, rewriter, rankedTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct GetProgramIdOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
#if 1
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
return success();
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
// "%clusterid".
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
|
||||
int numCTAs = triton::gpu_rocm::TritonGPUROCMDialect::getNumCTAs(moduleOp);
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
|
||||
sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
|
||||
Value programId = getSRegValue(rewriter, loc, sreg);
|
||||
rewriter.replaceOp(op, programId);
|
||||
return success();
|
||||
#endif
|
||||
}
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
};
|
||||
|
||||
struct GetNumProgramsOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::GetNumProgramsOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::GetNumProgramsOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
#if 1
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxis() < 3);
|
||||
// Seem like GridDimOp returns the number of threads (not the number of
|
||||
// workgroups) in a kernel (a bug in llvm https://reviews.llvm.org/D156009),
|
||||
// so as a workaround here, we divide by the number of threads
|
||||
// per workgroup to get the number of workgroups in a kernel.
|
||||
// TODO: when we do upstream to include llvm fix, we can remove this workaround
|
||||
// The unit test added in this PR can guarantee that.
|
||||
Value threadsPerGrid =
|
||||
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadsPerBlock =
|
||||
rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadNumPerGrid = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerGrid);
|
||||
Value threadNumPerBlock = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerBlock);
|
||||
rewriter.replaceOpWithNewOp<LLVM::UDivOp>(op, threadNumPerGrid, threadNumPerBlock);
|
||||
return success();
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetNumProgramsOp. If numCTAs = 1, then
|
||||
// GetNumProgramsOp is converted to "%nctaid", otherwise it is converted to
|
||||
// "%nclusterid".
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
|
||||
int numCTAs = triton::gpu_rocm::TritonGPUROCMDialect::getNumCTAs(moduleOp);
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxis() < 3);
|
||||
std::string sreg = numCTAs == 1 ? "%nctaid." : "%nclusterid.";
|
||||
sreg.append(1, 'x' + op.getAxis()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
|
||||
Value numPrograms = getSRegValue(rewriter, loc, sreg);
|
||||
rewriter.replaceOp(op, numPrograms);
|
||||
return success();
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
};
|
||||
|
||||
// TODO[goostavz]: GetThreadIdOp/GetClusterCTAIdOp is a temporary solution
|
||||
// before async dialect is done. These concepts should appear in ttgpu
|
||||
// level, and they are planned to be deprecated along with ttgpu.mbarrier_xxx
|
||||
// ops.
|
||||
struct GetThreadIdOpConversion : public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetThreadIdOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetThreadIdOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::GetThreadIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, getThreadId(rewriter, op->getLoc()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct GetClusterCTAIdOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetClusterCTAIdOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetClusterCTAIdOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::GetClusterCTAIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, getClusterCTAId(rewriter, op->getLoc()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AddPtrOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AddPtrOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType();
|
||||
auto offsetTy = op.getOffset().getType();
|
||||
auto ptrTy = op.getPtr().getType();
|
||||
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
|
||||
if (resultTensorTy) {
|
||||
unsigned elems = getTotalElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
getTypeConverter()->convertType(resultTensorTy.getElementType());
|
||||
auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(),
|
||||
rewriter, ptrTy);
|
||||
auto offsets = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOffset(), rewriter, offsetTy);
|
||||
SmallVector<Value> resultVals(elems);
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
|
||||
}
|
||||
Value view = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
|
||||
resultTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
} else {
|
||||
assert(resultTy.isa<triton::PointerType>());
|
||||
Type llResultTy = getTypeConverter()->convertType(resultTy);
|
||||
Value result = gep(llResultTy, adaptor.getPtr(), adaptor.getOffset());
|
||||
rewriter.replaceOp(op, result);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AllocTensorOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu_rocm::AllocTensorOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu_rocm::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu_rocm::AllocTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
auto llvmElemTy =
|
||||
getTypeConverter()->convertType(resultTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto sharedLayout = resultTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto order = sharedLayout.getOrder();
|
||||
// Workaround for 3D tensors
|
||||
// TODO: we need to modify the pipeline pass to give a proper shared
|
||||
// encoding to 3D tensors
|
||||
SmallVector<unsigned> newOrder;
|
||||
if (resultTy.getShape().size() == 3)
|
||||
newOrder = {1 + order[0], 1 + order[1], 0};
|
||||
else
|
||||
newOrder = SmallVector<unsigned>(order.begin(), order.end());
|
||||
|
||||
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
|
||||
auto smemObj =
|
||||
SharedMemoryObject(smemBase, shapePerCTA, newOrder, loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractSliceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu_rocm::ExtractSliceOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu_rocm::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu_rocm::ExtractSliceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// %dst = extract_slice %src[%offsets]
|
||||
Location loc = op->getLoc();
|
||||
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
|
||||
assert(op.hasUnitStride() &&
|
||||
"Only unit stride supported by ExtractSliceOpConversion");
|
||||
|
||||
// newBase = base + offset
|
||||
// Triton supports either static and dynamic offsets
|
||||
auto smemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.getSource(), rewriter);
|
||||
SmallVector<Value, 4> opOffsetVals;
|
||||
SmallVector<Value, 4> offsetVals;
|
||||
auto mixedOffsets = op.getMixedOffsets();
|
||||
for (auto i = 0; i < mixedOffsets.size(); ++i) {
|
||||
if (op.isDynamicOffset(i))
|
||||
opOffsetVals.emplace_back(adaptor.getOffsets()[i]);
|
||||
else
|
||||
opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
|
||||
offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
|
||||
}
|
||||
// Compute the offset based on the original strides of the shared memory
|
||||
// object
|
||||
auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides);
|
||||
// newShape = rank_reduce(shape)
|
||||
// Triton only supports static tensor sizes
|
||||
SmallVector<Value, 4> strideVals;
|
||||
for (auto i = 0; i < op.getStaticSizes().size(); ++i) {
|
||||
if (op.getStaticSize(i) == 1) {
|
||||
offsetVals.erase(offsetVals.begin() + i);
|
||||
} else {
|
||||
strideVals.emplace_back(smemObj.strides[i]);
|
||||
}
|
||||
}
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
|
||||
strideVals, offsetVals);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AsyncWaitOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu_rocm::AsyncWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu_rocm::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu_rocm::AsyncWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &asyncWaitOp = *ptxBuilder.create<>("cp.async.wait_group");
|
||||
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
|
||||
asyncWaitOp(ptxBuilder.newConstantOperand(num));
|
||||
|
||||
auto ctx = op.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto voidTy = void_ty(ctx);
|
||||
ptxBuilder.launch(rewriter, loc, voidTy);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AsyncCommitGroupOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu_rocm::AsyncCommitGroupOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu_rocm::AsyncCommitGroupOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu_rocm::AsyncCommitGroupOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
ptxBuilder.create<>("cp.async.commit_group")->operator()();
|
||||
ptxBuilder.launch(rewriter, op.getLoc(), void_ty(op.getContext()));
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AsyncBulkWaitOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu_rocm::AsyncBulkWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu_rocm::AsyncBulkWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu_rocm::AsyncBulkWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &asyncBulkWaitOp = *ptxBuilder.create<>("cp.async.bulk.wait_group");
|
||||
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
|
||||
asyncBulkWaitOp(ptxBuilder.newConstantOperand(num));
|
||||
|
||||
auto ctx = op.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto voidTy = void_ty(ctx);
|
||||
ptxBuilder.launch(rewriter, loc, voidTy);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AsyncBulkCommitGroupOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu_rocm::AsyncBulkCommitGroupOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu_rocm::AsyncBulkCommitGroupOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu_rocm::AsyncBulkCommitGroupOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
ptxBuilder.create<>("cp.async.bulk.commit_group")->operator()();
|
||||
ptxBuilder.launch(rewriter, op.getLoc(), void_ty(op.getContext()));
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
PrintOpConversion::llPrintf(msg, args, rewriter);
|
||||
}
|
||||
|
||||
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
std::string elem_repr, ConversionPatternRewriter &builder) {
|
||||
std::string fmt = info + " t-%d ";
|
||||
std::vector<Value> new_arr({thread});
|
||||
for (int i = 0; i < arr.size(); ++i) {
|
||||
fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
|
||||
new_arr.push_back(arr[i]);
|
||||
}
|
||||
|
||||
vprintf(fmt, new_arr, builder);
|
||||
}
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &moduleAllocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AllocTensorOpConversion>(typeConverter, moduleAllocation,
|
||||
benefit);
|
||||
patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncBulkCommitGroupOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncBulkWaitOpConversion>(typeConverter, benefit);
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
|
||||
benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetClusterCTAIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
|
||||
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AssertOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
16
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVM.h
vendored
Normal file
16
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVM.h
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
1444
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVMBase.h
vendored
Normal file
1444
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVMBase.h
vendored
Normal file
File diff suppressed because it is too large
Load Diff
936
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVMPass.cpp
vendored
Normal file
936
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVMPass.cpp
vendored
Normal file
@@ -0,0 +1,936 @@
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/TritonGPUToLLVMPass.h"
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
||||
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "triton/AnalysisROCM/Allocation.h"
|
||||
#include "triton/AnalysisROCM/AxisInfo.h"
|
||||
#include "triton/AnalysisROCM/Membar.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
#if 0
|
||||
#else
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#endif
|
||||
#include "triton/Tools/Sys/GetPlatform.hpp"
|
||||
|
||||
#include "BarrierOpToLLVM.h"
|
||||
#include "ClusterOpsToLLVM.h"
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "ElementwiseOpToLLVM.h"
|
||||
#include "LoadStoreOpToLLVM.h"
|
||||
#include "ReduceOpToLLVM.h"
|
||||
#include "RegReallocOpToLLVM.h"
|
||||
#include "ScanOpToLLVM.h"
|
||||
#include "TensorPtrOpsToLLVM.h"
|
||||
#include "TritonGPUToLLVM.h"
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
#include "TypeConverter.h"
|
||||
#include "ViewOpToLLVM.h"
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
#define GEN_PASS_DEF_CONVERTTRITONGPUROCMTOLLVM
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/Passes.h.inc"
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
namespace ttng = mlir::triton::nvidia_gpu;
|
||||
|
||||
namespace {
|
||||
|
||||
// pass ws related named attrs.
|
||||
static void addWSNamedAttrs(Operation *op,
|
||||
ArrayRef<mlir::NamedAttribute> attrs) {
|
||||
for (const NamedAttribute attr : attrs)
|
||||
if (attr.getName() == "async_agent" || attr.getName() == "agent.mutex_role")
|
||||
op->setAttr(attr.getName(), attr.getValue());
|
||||
}
|
||||
|
||||
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<index::IndexDialect>();
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
addLegalDialect<mlir::scf::SCFDialect>();
|
||||
break;
|
||||
}
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
};
|
||||
|
||||
class FoldSplatMaskInInsertAsync : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
FoldSplatMaskInInsertAsync(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(
|
||||
triton::nvidia_gpu::InsertSliceAsyncV2Op::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto insertOp = cast<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op);
|
||||
if (!insertOp.getMask())
|
||||
return failure();
|
||||
auto splatOp = insertOp.getMask().getDefiningOp<triton::SplatOp>();
|
||||
if (!splatOp)
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(insertOp, [&]() {
|
||||
insertOp.getMaskMutable().assign(splatOp->getOperand(0));
|
||||
});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
|
||||
if (funcOp->hasAttr("nvvm.kernel")) {
|
||||
// A GPU kernel
|
||||
if (op.getNumOperands() > 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Kernel functions do not support return with operands");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
|
||||
op->getAttrs());
|
||||
} else {
|
||||
// A device function
|
||||
LLVM::ReturnOp newOp;
|
||||
if (adaptor.getOperands().size() < 2) {
|
||||
// Single or no return value.
|
||||
newOp =
|
||||
rewriter.create<LLVM::ReturnOp>(op.getLoc(), adaptor.getOperands());
|
||||
} else {
|
||||
// Pack the results into a struct.
|
||||
auto packedResultsTy = this->getTypeConverter()->packFunctionResults(
|
||||
funcOp.getResultTypes());
|
||||
Value packedResults =
|
||||
rewriter.create<LLVM::UndefOp>(op.getLoc(), packedResultsTy);
|
||||
auto loc = op.getLoc();
|
||||
for (auto it : llvm::enumerate(adaptor.getOperands())) {
|
||||
packedResults = insert_val(packedResultsTy, packedResults, it.value(),
|
||||
it.index());
|
||||
}
|
||||
newOp = rewriter.create<LLVM::ReturnOp>(op.getLoc(), packedResults);
|
||||
}
|
||||
newOp->setAttrs(op->getAttrs());
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
|
||||
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
|
||||
/// information.
|
||||
struct FuncOpConversion : public FuncOpConversionBase {
|
||||
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit)
|
||||
: FuncOpConversionBase(converter, benefit), numWarps(numWarps),
|
||||
allocation(allocation) {}
|
||||
|
||||
triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Push back a variable that indicates the current stack pointer of shared
|
||||
// memory to the function arguments.
|
||||
auto loc = funcOp.getLoc();
|
||||
auto ctx = funcOp->getContext();
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
||||
// 1. Modify the function type to add the new argument.
|
||||
auto funcTy = funcOp.getFunctionType();
|
||||
auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
|
||||
amendedInputTy.push_back(ptrTy);
|
||||
auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy,
|
||||
funcTy.getResults());
|
||||
// 2. Modify the argument attributes to add the new argument.
|
||||
SmallVector<NamedAttribute> amendedAttrs;
|
||||
filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs);
|
||||
auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs());
|
||||
amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx));
|
||||
amendedAttrs.push_back(rewriter.getNamedAttr(
|
||||
funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs)));
|
||||
// 3. Add a new argument to the region
|
||||
auto amendedFuncOp = rewriter.create<triton::FuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs);
|
||||
auto ®ion = funcOp.getBody();
|
||||
region.addArgument(ptrTy, loc);
|
||||
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
|
||||
amendedFuncOp.end());
|
||||
return amendedFuncOp;
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Prevent LLVM's inliner to inline this function
|
||||
auto amendedFuncOp = funcOp;
|
||||
if (!allocation.isRoot(funcOp))
|
||||
amendedFuncOp = amendFuncOp(funcOp, rewriter);
|
||||
|
||||
// Collect TMA informations.
|
||||
unsigned numTMALoad = 0;
|
||||
funcOp.walk(
|
||||
[&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
|
||||
numTMALoad++;
|
||||
});
|
||||
unsigned numTMAStore = 0;
|
||||
funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
|
||||
numTMAStore++;
|
||||
});
|
||||
unsigned numTMA = numTMALoad + numTMAStore;
|
||||
|
||||
auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
|
||||
if (!newFuncOp) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto ctx = funcOp->getContext();
|
||||
|
||||
if (allocation.isRoot(funcOp)) {
|
||||
// Set an attribute to indicate this function is a kernel entry.
|
||||
newFuncOp->setAttr("nvvm.kernel",
|
||||
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
|
||||
} else {
|
||||
// The noinline attribute will be used by the LLVM codegen to prevent
|
||||
// inlining.
|
||||
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
|
||||
newFuncOp.setPassthroughAttr(
|
||||
ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
|
||||
rewriter.eraseOp(amendedFuncOp);
|
||||
}
|
||||
#if 0
|
||||
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
|
||||
// for `nvvm.annotation` metadata.
|
||||
newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps));
|
||||
#endif
|
||||
// The call graph is updated by mapping the old function to the new one.
|
||||
allocation.mapFuncOp(funcOp, newFuncOp);
|
||||
|
||||
// Append arguments to receive TMADesc in global memory in the runtime
|
||||
auto i8PtrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 1);
|
||||
auto numArgs = newFuncOp.getBody().front().getNumArguments();
|
||||
auto funcTy = newFuncOp.getFunctionType().cast<LLVM::LLVMFunctionType>();
|
||||
SmallVector<Type> newInputsTy(funcTy.getParams().begin(),
|
||||
funcTy.getParams().end());
|
||||
for (unsigned i = 0; i < numTMA; ++i) {
|
||||
newFuncOp.getBody().front().addArgument(i8PtrTy, funcOp.getLoc());
|
||||
newInputsTy.push_back(i8PtrTy);
|
||||
}
|
||||
newFuncOp.setType(
|
||||
LLVM::LLVMFunctionType::get(funcTy.getReturnType(), newInputsTy));
|
||||
// required by AxisInfoAnalysis
|
||||
for (unsigned i = 0; i < numTMA; ++i) {
|
||||
newFuncOp.setArgAttr(numArgs + i, "tt.divisibility",
|
||||
rewriter.getIntegerAttr(i32_ty, 1));
|
||||
}
|
||||
|
||||
newFuncOp->setAttr(kAttrNumTMALoadDescsName,
|
||||
rewriter.getIntegerAttr(i32_ty, numTMALoad));
|
||||
newFuncOp->setAttr(kAttrNumTMAStoreDescsName,
|
||||
rewriter.getIntegerAttr(i32_ty, numTMAStore));
|
||||
|
||||
rewriter.eraseOp(funcOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
int numWarps{0};
|
||||
ModuleAllocation &allocation;
|
||||
};
|
||||
|
||||
// CallOpInterfaceLowering is adapted from
|
||||
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485
|
||||
struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
|
||||
CallOpConversion(LLVMTypeConverter &converter, int numWarps,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit)
|
||||
: ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
|
||||
numWarps(numWarps), allocation(allocation) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::CallOp callOp,
|
||||
typename triton::CallOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto promotedOperands = promoteOperands(callOp, adaptor, rewriter);
|
||||
auto newCallOp =
|
||||
convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter);
|
||||
if (!newCallOp)
|
||||
return failure();
|
||||
allocation.mapCallOp(callOp, newCallOp);
|
||||
auto results = getCallOpResults(callOp, newCallOp, rewriter);
|
||||
rewriter.replaceOp(callOp, results);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
SmallVector<Value, 4>
|
||||
promoteOperands(triton::CallOp callOp,
|
||||
typename triton::CallOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Get the last argument of the caller, which is the current stack pointer
|
||||
// of shared memory and append it to the operands of the callOp.
|
||||
auto loc = callOp.getLoc();
|
||||
auto caller = callOp->getParentOfType<FunctionOpInterface>();
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()),
|
||||
NVVM::kSharedMemorySpace);
|
||||
auto promotedOperands = this->getTypeConverter()->promoteOperands(
|
||||
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
|
||||
adaptor.getOperands(), rewriter);
|
||||
auto base = allocation.getFunctionSharedMemoryBase(caller);
|
||||
auto *funcAllocation = allocation.getFuncData(caller);
|
||||
auto bufferId = funcAllocation->getBufferId(callOp);
|
||||
// function doesn't have a shared mem buffer
|
||||
if (bufferId == (size_t)-1) {
|
||||
promotedOperands.push_back(base);
|
||||
return promotedOperands;
|
||||
}
|
||||
// function has a shared mem buffer
|
||||
auto offset = funcAllocation->getOffset(bufferId);
|
||||
auto offsetValue = gep(ptrTy, base, i32_val(offset));
|
||||
promotedOperands.push_back(offsetValue);
|
||||
return promotedOperands;
|
||||
}
|
||||
|
||||
LLVM::CallOp
|
||||
convertCallOpToLLVMCallOp(triton::CallOp callOp,
|
||||
ArrayRef<Value> promotedOperands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Pack the result types into a struct.
|
||||
Type packedResult = nullptr;
|
||||
unsigned numResults = callOp.getNumResults();
|
||||
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
|
||||
|
||||
if (numResults != 0) {
|
||||
if (!(packedResult =
|
||||
this->getTypeConverter()->packFunctionResults(resultTypes)))
|
||||
return nullptr;
|
||||
}
|
||||
auto newCallOp = rewriter.create<LLVM::CallOp>(
|
||||
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
|
||||
promotedOperands, callOp->getAttrs());
|
||||
return newCallOp;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto numResults = callOp.getNumResults();
|
||||
SmallVector<Value> results;
|
||||
if (numResults < 2) {
|
||||
// If < 2 results, packing did not do anything and we can just return.
|
||||
results.append(newCallOp.result_begin(), newCallOp.result_end());
|
||||
} else {
|
||||
// Otherwise, it had been converted to an operation producing a structure.
|
||||
// Extract individual results from the structure and return them as list.
|
||||
results.reserve(numResults);
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
callOp.getLoc(), newCallOp->getResult(0), i));
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
int numWarps{0};
|
||||
ModuleAllocation &allocation;
|
||||
};
|
||||
|
||||
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx, Target target)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
addLegalDialect<mlir::scf::SCFDialect>();
|
||||
break;
|
||||
}
|
||||
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
|
||||
addIllegalDialect<triton::TritonDialect>();
|
||||
addIllegalDialect<triton::gpu_rocm::TritonGPUROCMDialect>();
|
||||
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
|
||||
addIllegalDialect<mlir::gpu::GPUDialect>();
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTritonGPUROCMToLLVM
|
||||
: public triton::impl::ConvertTritonGPUROCMToLLVMBase<ConvertTritonGPUROCMToLLVM> {
|
||||
using ConvertTritonGPUROCMToLLVMBase<
|
||||
ConvertTritonGPUROCMToLLVM>::ConvertTritonGPUROCMToLLVMBase;
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
option.overrideIndexBitwidth(32);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMConversionTarget convTarget(*context, target);
|
||||
int numWarps = triton::gpu_rocm::TritonGPUROCMDialect::getNumWarps(mod);
|
||||
int numCTAs = triton::gpu_rocm::TritonGPUROCMDialect::getNumCTAs(mod);
|
||||
int threadsPerWarp = triton::gpu_rocm::TritonGPUROCMDialect::getThreadsPerWarp(mod);
|
||||
|
||||
// Preprocess
|
||||
decomposeFp8e4b15Convert(mod);
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
|
||||
#if 1
|
||||
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
|
||||
#endif
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
decomposeMixedModeDotOp(mod);
|
||||
|
||||
// Allocate shared memory and set barrier
|
||||
ModuleAllocation allocation(mod);
|
||||
ModuleMembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
/* Get tensorPtrMap before conversion */
|
||||
TensorPtrMapT tensorPtrMap;
|
||||
mod.walk([&tensorPtrMap](
|
||||
mlir::triton::nvidia_gpu::InsertSliceAsyncV2Op insertOp) {
|
||||
auto src = insertOp.getSrc();
|
||||
auto ptrTy = src.getType().dyn_cast<triton::PointerType>();
|
||||
if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(insertOp.getSrc());
|
||||
tensorPtrMap[insertOp.getOperation()] = makeTensorPtrOp;
|
||||
}
|
||||
});
|
||||
|
||||
mod.walk([&tensorPtrMap](mlir::triton::nvidia_gpu::StoreAsyncOp storeOp) {
|
||||
auto dst = storeOp.getDst();
|
||||
auto ptrTy = dst.getType().dyn_cast<triton::PointerType>();
|
||||
if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(storeOp.getDst());
|
||||
tensorPtrMap[storeOp.getOperation()] = makeTensorPtrOp;
|
||||
}
|
||||
});
|
||||
|
||||
// Hack: cleanup
|
||||
{
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<FoldSplatMaskInInsertAsync>(context);
|
||||
SmallVector<Operation *> insertSlices;
|
||||
mod.walk([&insertSlices](triton::nvidia_gpu::InsertSliceAsyncV2Op op) {
|
||||
insertSlices.push_back(op);
|
||||
});
|
||||
if (applyOpPatternsAndFold(insertSlices, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
// Lower functions
|
||||
{
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context, target);
|
||||
RewritePatternSet funcPatterns(context);
|
||||
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, allocation,
|
||||
/*benefit=*/1);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
funcPatterns);
|
||||
if (failed(
|
||||
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
// initSharedMemory is run before the conversion of call and ret ops,
|
||||
// because the call op has to know the shared memory base address of each
|
||||
// function
|
||||
initSharedMemory(allocation, typeConverter);
|
||||
|
||||
// Convert call and ret ops
|
||||
{
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context, target);
|
||||
RewritePatternSet funcPatterns(context);
|
||||
funcPatterns.add<CallOpConversion>(typeConverter, numWarps, allocation,
|
||||
/*benefit=*/1);
|
||||
funcPatterns.add<ReturnOpConversion>(typeConverter, /*benefit=*/1);
|
||||
if (failed(
|
||||
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
|
||||
|
||||
// Emit logics to get threadId/blockIds/linearized clusterCTAId etc. and
|
||||
// cache the values. The reason to do it here is that cluster_ctaid is
|
||||
// currently implemented via inline asm, and thus cannot be CSEed.
|
||||
// clusterCTAId will be emitted only when numCTAs is larger than 1, and
|
||||
// other values will be DCEed if not used hereafter.
|
||||
bool isWarpSpecialization =
|
||||
ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod);
|
||||
OpBuilder::InsertPoint indexInsertPoint;
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
|
||||
&baseIndexCache, &indexCache, &indexInsertPoint};
|
||||
// TODO: enable index cache if there are multiple functions
|
||||
if (axisInfoAnalysis.getNumFunctions() > 1) {
|
||||
indexCacheInfo = {nullptr, nullptr, nullptr};
|
||||
}
|
||||
|
||||
// tmaMetadata is absent in a triton-opt unit test, in this case, create a
|
||||
// local one and dump it after this pass is done.
|
||||
mlir::triton::gpu::TMAMetadataTy tmaMetaDataDebug;
|
||||
if (tmaMetadata == nullptr)
|
||||
tmaMetadata = &tmaMetaDataDebug;
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
auto populatePatterns1 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo,
|
||||
/*benefit*/ 10);
|
||||
};
|
||||
|
||||
auto populatePatterns2 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, /*benefit*/ 10);
|
||||
};
|
||||
|
||||
auto populatePatterns3 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo, tmaMetadata, &tensorPtrMap,
|
||||
/*benefit*/ 10);
|
||||
};
|
||||
|
||||
auto populatePatterns4 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo, computeCapability,
|
||||
/*benefit*/ 10);
|
||||
};
|
||||
|
||||
populatePatterns1(populateTritonGPUToLLVMPatterns);
|
||||
populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
|
||||
populatePatterns2(populateDotOpToLLVMPatterns);
|
||||
populatePatterns4(populateElementwiseOpToLLVMPatterns);
|
||||
populatePatterns3(populateLoadStoreOpToLLVMPatterns);
|
||||
populatePatterns4(populateReduceOpToLLVMPatterns);
|
||||
populatePatterns1(populateScanOpToLLVMPatterns);
|
||||
populatePatterns2(populateViewOpToLLVMPatterns);
|
||||
populatePatterns2(populateBarrierOpToLLVMPatterns);
|
||||
populatePatterns2(populateTensorPtrOpsToLLVMPatterns);
|
||||
populatePatterns2(populateClusterOpsToLLVMPatterns);
|
||||
populatePatterns2(populateRegReallocOpToLLVMPatterns);
|
||||
|
||||
// TODO(thomas): this should probably be done in a separate step to not
|
||||
// interfere with our own lowering of arith ops. Add arith/math's patterns
|
||||
// to help convert scalar expression to LLVM.
|
||||
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// Native lowering patterns
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns,
|
||||
mlir::gpu::amd::HIP);
|
||||
break;
|
||||
}
|
||||
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Fold CTAId when there is only 1 CTA.
|
||||
if (numCTAs == 1) {
|
||||
mod.walk([](triton::nvgpu::ClusterCTAIdOp id) {
|
||||
OpBuilder b(id);
|
||||
Value zero = LLVM::createConstantI32(id->getLoc(), b, 0);
|
||||
id.replaceAllUsesWith(zero);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
baseIndexCache;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||
CacheKeyDenseMapInfo>
|
||||
indexCache;
|
||||
|
||||
void initSharedMemory(ModuleAllocation &allocation,
|
||||
TritonGPUToLLVMTypeConverter &typeConverter) {
|
||||
ModuleOp mod = getOperation();
|
||||
OpBuilder b(mod.getBodyRegion());
|
||||
auto ctx = mod.getContext();
|
||||
auto loc = mod.getLoc();
|
||||
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
|
||||
// Set array size 0 and external linkage indicates that we use dynamic
|
||||
// shared allocation to allow a larger shared memory size for each kernel.
|
||||
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
|
||||
auto global = b.create<LLVM::GlobalOp>(
|
||||
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
|
||||
"global_smem", /*value=*/Attribute(), /*alignment=*/0,
|
||||
// Add ROCm support.
|
||||
static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace));
|
||||
mod.walk([&](FunctionOpInterface funcOp) {
|
||||
Value funcSmem;
|
||||
b.setInsertionPointToStart(&funcOp.getFunctionBody().front());
|
||||
if (allocation.isRoot(funcOp)) {
|
||||
funcSmem = b.create<LLVM::AddressOfOp>(loc, global);
|
||||
} else {
|
||||
funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1);
|
||||
}
|
||||
auto ptrTy =
|
||||
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()),
|
||||
NVVM::NVVMMemorySpace::kSharedMemorySpace);
|
||||
funcSmem = b.create<LLVM::BitcastOp>(loc, ptrTy, funcSmem);
|
||||
allocation.setFunctionSharedMemoryValue(funcOp, funcSmem);
|
||||
});
|
||||
mod->setAttr("triton_gpu_rocm.shared",
|
||||
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
|
||||
allocation.getSharedMemorySize()));
|
||||
}
|
||||
|
||||
void decomposeFp8e4b15Convert(ModuleOp mod) const {
|
||||
mod.walk([&](triton::gpu_rocm::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
if (!getElementTypeOrSelf(cvtOp)
|
||||
.isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType>())
|
||||
return;
|
||||
auto shape = cvtOp.getType().cast<RankedTensorType>().getShape();
|
||||
auto argEncoding =
|
||||
cvtOp.getOperand().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto cvtEncoding = cvtOp.getType().cast<RankedTensorType>().getEncoding();
|
||||
if (argEncoding.isa<triton::gpu_rocm::DotOperandEncodingAttr>() ||
|
||||
cvtEncoding.isa<triton::gpu_rocm::DotOperandEncodingAttr>())
|
||||
return;
|
||||
auto F16Ty = builder.getF16Type();
|
||||
|
||||
auto newArgType = RankedTensorType::get(shape, F16Ty, argEncoding);
|
||||
auto newCvtType = RankedTensorType::get(shape, F16Ty, cvtEncoding);
|
||||
auto newArg = builder.create<mlir::triton::FpToFpOp>(
|
||||
cvtOp.getLoc(), newArgType, cvtOp.getOperand());
|
||||
addWSNamedAttrs(newArg, cvtOp->getAttrs());
|
||||
auto newCvt = builder.create<mlir::triton::gpu_rocm::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), newCvtType, newArg);
|
||||
addWSNamedAttrs(newCvt, cvtOp->getAttrs());
|
||||
auto newRet = builder.create<mlir::triton::FpToFpOp>(
|
||||
cvtOp.getLoc(), cvtOp.getType(), newCvt.getResult());
|
||||
addWSNamedAttrs(newRet, cvtOp->getAttrs());
|
||||
cvtOp.replaceAllUsesWith(newRet.getResult());
|
||||
cvtOp.erase();
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
|
||||
int numCTAs) const {
|
||||
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu_rocm::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcMma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu_rocm::MmaEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu_rocm::DotOperandEncodingAttr>();
|
||||
if (srcMma && dstDotOp && !isMmaToDotShortcut(srcType, dstType)) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu_rocm::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
|
||||
getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
|
||||
auto tmp = builder.create<triton::gpu_rocm::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
addWSNamedAttrs(tmp, cvtOp->getAttrs());
|
||||
auto newConvert = builder.create<triton::gpu_rocm::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
addWSNamedAttrs(newConvert, cvtOp->getAttrs());
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#if 1
|
||||
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
|
||||
int numCTAs) const {
|
||||
// Replace `mfma -> dot_op` with `mfma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu_rocm::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcMfma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu_rocm::MfmaEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu_rocm::DotOperandEncodingAttr>();
|
||||
if (srcMfma && dstDotOp && !isMfmaToDotShortcut(srcType, dstType)) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu_rocm::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMfma),
|
||||
getOrder(srcMfma), numWarps, threadsPerWarp, numCTAs));
|
||||
auto tmp = builder.create<triton::gpu_rocm::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu_rocm::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
void decomposeBlockedToDotOperand(ModuleOp mod) const {
|
||||
// Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
|
||||
// because the codegen doesn't handle `blocked -> dot_op` directly
|
||||
mod.walk([&](triton::gpu_rocm::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu_rocm::BlockedEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu_rocm::DotOperandEncodingAttr>();
|
||||
if (srcBlocked && dstDotOp) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu_rocm::SharedEncodingAttr::get(
|
||||
mod.getContext(), dstDotOp, srcType.getShape(),
|
||||
srcBlocked.getOrder(), srcBlocked.getCTALayout(),
|
||||
srcType.getElementType()));
|
||||
auto tmp = builder.create<triton::gpu_rocm::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
addWSNamedAttrs(tmp, cvtOp->getAttrs());
|
||||
auto newConvert = builder.create<triton::gpu_rocm::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
addWSNamedAttrs(newConvert, cvtOp->getAttrs());
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
|
||||
// TODO(Keren): This is a hacky knob that may cause performance regression
|
||||
// when decomposition has been performed. We should remove this knob once we
|
||||
// have thorough analysis on async wait. Currently, we decompose
|
||||
// `insert_slice_async` into `load` and `insert_slice` without knowing which
|
||||
// `async_wait` is responsible for the `insert_slice_async`. To guarantee
|
||||
// correctness, we blindly set the `async_wait` to wait for all async ops.
|
||||
//
|
||||
// There are two options to improve this:
|
||||
// 1. We can perform a dataflow analysis to find the `async_wait` that is
|
||||
// responsible for the `insert_slice_async` in the backend.
|
||||
// 2. We can modify the pipeline to perform the decomposition before the
|
||||
// `async_wait` is inserted. However, it is also risky because we don't know
|
||||
// the correct vectorized shape yet in the pipeline pass. Making the
|
||||
// pipeline pass aware of the vectorization could introduce additional
|
||||
// dependencies on the AxisInfoAnalysis and the Coalesce analysis.
|
||||
bool decomposed = false;
|
||||
// insert_slice_async %src, %dst, %idx, %mask, %other
|
||||
// =>
|
||||
// %tmp = load %src, %mask, %other
|
||||
// %res = insert_slice %tmp into %dst[%idx]
|
||||
mod.walk([&](triton::gpu_rocm::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
|
||||
OpBuilder builder(insertSliceAsyncOp);
|
||||
|
||||
// Get the vectorized load size
|
||||
auto src = insertSliceAsyncOp.getSrc();
|
||||
auto dst = insertSliceAsyncOp.getDst();
|
||||
auto mask = insertSliceAsyncOp.getMask();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcTy.getEncoding().dyn_cast<triton::gpu_rocm::BlockedEncodingAttr>();
|
||||
auto resSharedLayout =
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu_rocm::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
|
||||
if (mask)
|
||||
inVec =
|
||||
std::min<unsigned>(axisInfoAnalysis.getMaskAlignment(mask), inVec);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = inVec;
|
||||
if (outVec > 1)
|
||||
minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
|
||||
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
|
||||
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
|
||||
auto byteWidth = bitWidth / 8;
|
||||
|
||||
// If the load byte width is not eligible or the current compute
|
||||
// capability does not support async copy, then we do decompose
|
||||
#if 0
|
||||
if (triton::gpu_rocm::InsertSliceAsyncOp::getEligibleLoadByteWidth(
|
||||
computeCapability)
|
||||
.contains(byteWidth)) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
// load
|
||||
auto tmpTy =
|
||||
RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
|
||||
auto loadOp = builder.create<triton::LoadOp>(
|
||||
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
|
||||
insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
|
||||
// TODO(Chenggang): confirm `boundaryCheck` and `padding`
|
||||
/*boundaryCheck=*/nullptr, /*padding=*/nullptr,
|
||||
insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(),
|
||||
insertSliceAsyncOp.getIsVolatile());
|
||||
addWSNamedAttrs(loadOp, insertSliceAsyncOp->getAttrs());
|
||||
|
||||
// insert_slice
|
||||
auto axis = insertSliceAsyncOp.getAxis();
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
|
||||
auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
|
||||
auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
|
||||
offsets[axis] = insertSliceAsyncOp.getIndex();
|
||||
for (size_t i = 0; i < dstTy.getRank(); i++) {
|
||||
if (i != axis)
|
||||
sizes[i] = intAttr(dstTy.getShape()[i]);
|
||||
}
|
||||
auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
|
||||
insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.getDst(),
|
||||
offsets, sizes, strides);
|
||||
addWSNamedAttrs(insertSliceOp, insertSliceAsyncOp->getAttrs());
|
||||
|
||||
// Replace
|
||||
insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
|
||||
insertSliceAsyncOp.erase();
|
||||
decomposed = true;
|
||||
});
|
||||
|
||||
mod.walk([&](triton::gpu_rocm::AsyncCommitGroupOp asyncCommitGroupOp) -> void {
|
||||
if (!triton::gpu_rocm::AsyncCommitGroupOp::isSupported(computeCapability))
|
||||
asyncCommitGroupOp.erase();
|
||||
});
|
||||
|
||||
mod.walk([&](triton::gpu_rocm::AsyncWaitOp asyncWaitOp) -> void {
|
||||
#if 1
|
||||
// AsyncWait is not supported for ROCM and should be removed
|
||||
asyncWaitOp.erase();
|
||||
#else
|
||||
if (!triton::gpu_rocm::AsyncWaitOp::isSupported(computeCapability)) {
|
||||
// async wait is supported in Ampere and later
|
||||
asyncWaitOp.erase();
|
||||
} else if (decomposed) {
|
||||
// Wait for all previous async ops
|
||||
OpBuilder builder(asyncWaitOp);
|
||||
auto newWaitOp =
|
||||
builder.create<triton::gpu_rocm::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
||||
addWSNamedAttrs(newWaitOp, asyncWaitOp->getAttrs());
|
||||
asyncWaitOp.erase();
|
||||
}
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
|
||||
Type promotedType) {
|
||||
Type tensorPromotedType =
|
||||
operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
|
||||
promotedType);
|
||||
return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
|
||||
}
|
||||
|
||||
// promote operands of dot op if the existing combination is not natively
|
||||
// supported.
|
||||
void decomposeMixedModeDotOp(ModuleOp mod) const {
|
||||
mod.walk([](triton::DotOp dotOp) -> void {
|
||||
Value D = dotOp.getResult();
|
||||
OpBuilder builder(dotOp);
|
||||
Type AElType =
|
||||
dotOp.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
Type promoteType;
|
||||
MmaEncodingAttr mmaLayout = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MmaEncodingAttr>();
|
||||
if (mmaLayout) {
|
||||
bool isNativeHopperFP8 =
|
||||
AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
|
||||
bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() ||
|
||||
AElType.isFloat8E4M3FN();
|
||||
if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper()))
|
||||
return;
|
||||
promoteType = builder.getF16Type();
|
||||
#if 1
|
||||
} else if (MfmaEncodingAttr mfmaLayout =
|
||||
D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MfmaEncodingAttr>()) {
|
||||
if (AElType.isBF16() || AElType.isF16() || AElType.isF32() ||
|
||||
AElType.isInteger(8))
|
||||
return;
|
||||
promoteType = builder.getF16Type();
|
||||
#endif
|
||||
} else {
|
||||
// FMA case.
|
||||
Type AElType =
|
||||
dotOp.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
Type DElType = D.getType().cast<RankedTensorType>().getElementType();
|
||||
if (AElType == DElType)
|
||||
return;
|
||||
promoteType = DElType;
|
||||
}
|
||||
Location loc = dotOp.getLoc();
|
||||
Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType);
|
||||
Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType);
|
||||
dotOp.setOperand(0, promotedA);
|
||||
dotOp.setOperand(1, promotedB);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUROCMToLLVMPass() {
|
||||
return std::make_unique<ConvertTritonGPUROCMToLLVM>();
|
||||
}
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUROCMToLLVMPass(const ConvertTritonGPUROCMToLLVMOptions &options) {
|
||||
return std::make_unique<ConvertTritonGPUROCMToLLVM>(options);
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
208
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TypeConverter.cpp
vendored
Normal file
208
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TypeConverter.cpp
vendored
Normal file
@@ -0,0 +1,208 @@
|
||||
#include "TypeConverter.h"
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu_rocm::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu_rocm::MfmaEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu_rocm::SliceEncodingAttr;
|
||||
|
||||
TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
|
||||
MLIRContext *ctx, LowerToLLVMOptions &option,
|
||||
const DataLayoutAnalysis *analysis)
|
||||
: LLVMTypeConverter(ctx, option, analysis) {
|
||||
addConversion([&](triton::PointerType type) -> std::optional<Type> {
|
||||
return convertTritonPointerType(type);
|
||||
});
|
||||
addConversion([&](RankedTensorType type) -> std::optional<Type> {
|
||||
return convertTritonTensorType(type);
|
||||
});
|
||||
// Internally store float8 as int8
|
||||
addConversion([&](mlir::Float8E4M3B11FNUZType type) -> std::optional<Type> {
|
||||
return IntegerType::get(type.getContext(), 8);
|
||||
});
|
||||
addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
|
||||
return IntegerType::get(type.getContext(), 8);
|
||||
});
|
||||
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
|
||||
return IntegerType::get(type.getContext(), 8);
|
||||
});
|
||||
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
|
||||
return IntegerType::get(type.getContext(), 8);
|
||||
});
|
||||
// Internally store bfloat16 as int16
|
||||
addConversion([&](BFloat16Type type) -> std::optional<Type> {
|
||||
return IntegerType::get(type.getContext(), 16);
|
||||
});
|
||||
}
|
||||
|
||||
Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
|
||||
triton::PointerType type) {
|
||||
auto ctx = type.getContext();
|
||||
auto pointeeType = type.getPointeeType();
|
||||
if (pointeeType.isa<RankedTensorType>()) {
|
||||
auto rankedTensorType = pointeeType.cast<RankedTensorType>();
|
||||
// struct { offset0, offset1, shape0, shape1, stride0,
|
||||
// stride1, base_ptr};
|
||||
auto eleType = rankedTensorType.getElementType();
|
||||
auto shape = rankedTensorType.getShape();
|
||||
SmallVector<Type, 4> types;
|
||||
// offsets
|
||||
for (size_t i = 0; i < shape.size(); ++i)
|
||||
types.push_back(IntegerType::get(ctx, 32));
|
||||
// shapes, strides
|
||||
for (size_t i = 0; i < 2 * shape.size(); ++i)
|
||||
types.push_back(IntegerType::get(ctx, 64));
|
||||
|
||||
types.push_back(
|
||||
LLVM::LLVMPointerType::get(eleType, type.getAddressSpace()));
|
||||
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
}
|
||||
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
|
||||
type.getAddressSpace());
|
||||
}
|
||||
|
||||
Value TritonGPUToLLVMTypeConverter::packLLElements(
|
||||
Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter,
|
||||
Type type) {
|
||||
auto structType = this->convertType(type).dyn_cast<LLVM::LLVMStructType>();
|
||||
if (!structType) {
|
||||
assert(resultVals.size() == 1);
|
||||
return *resultVals.begin();
|
||||
}
|
||||
|
||||
auto elementTypes = structType.getBody();
|
||||
if (elementTypes.size() != resultVals.size()) {
|
||||
emitError(loc) << " size mismatch when packing elements for LLVM struct"
|
||||
<< " expected " << elementTypes.size() << " but got "
|
||||
<< resultVals.size();
|
||||
}
|
||||
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
|
||||
for (const auto &v : llvm::enumerate(resultVals)) {
|
||||
if (!v.value()) {
|
||||
emitError(loc)
|
||||
<< "cannot insert null values into struct, but tried to insert"
|
||||
<< v.value();
|
||||
}
|
||||
if (v.value().getType() != elementTypes[v.index()]) {
|
||||
emitError(loc) << "invalid element type in packLLEElements. Expected "
|
||||
<< elementTypes[v.index()] << " but got "
|
||||
<< v.value().getType();
|
||||
}
|
||||
llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
|
||||
}
|
||||
return llvmStruct;
|
||||
}
|
||||
|
||||
SmallVector<Value> TritonGPUToLLVMTypeConverter::packMfmaOperand(
|
||||
const SmallVector<Value> &inValues, Type srcTy,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return inValues;
|
||||
auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
|
||||
if (!(encoding && encoding.getParent().isa<MfmaEncodingAttr>())) {
|
||||
return inValues;
|
||||
}
|
||||
|
||||
auto structType = this->convertType(srcTy).dyn_cast<LLVM::LLVMStructType>();
|
||||
auto elementTypes = structType.getBody();
|
||||
assert(elementTypes.size() > 0);
|
||||
mlir::VectorType vecTy = elementTypes[0].dyn_cast<mlir::VectorType>();
|
||||
if (!vecTy) return inValues;
|
||||
|
||||
unsigned size = vecTy.getNumElements();
|
||||
|
||||
SmallVector<Value> result;
|
||||
for (int i = 0; i < inValues.size(); i += size) {
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned j = 0; j < size; ++j) {
|
||||
valVec = insert_element(vecTy, valVec, inValues[i + j], i32_val(j));
|
||||
}
|
||||
result.push_back(valVec);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
SmallVector<Value> TritonGPUToLLVMTypeConverter::unpackLLElements(
|
||||
Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter,
|
||||
Type type) {
|
||||
assert(bool(llvmStruct) && "can not unpack null values");
|
||||
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
|
||||
llvmStruct.getType().isa<triton::PointerType>() ||
|
||||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
|
||||
return {llvmStruct};
|
||||
ArrayRef<Type> types =
|
||||
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
|
||||
SmallVector<Value> results(types.size());
|
||||
for (unsigned i = 0; i < types.size(); ++i) {
|
||||
Type type = types[i];
|
||||
results[i] = extract_val(type, llvmStruct, i);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
|
||||
RankedTensorType type) {
|
||||
auto ctx = type.getContext();
|
||||
Attribute layout = type.getEncoding();
|
||||
Type elemTy = convertType(type.getElementType());
|
||||
auto dotOpLayout = layout.dyn_cast<DotOperandEncodingAttr>();
|
||||
if (!dotOpLayout)
|
||||
return elemTy;
|
||||
|
||||
#if 1
|
||||
if (auto mfmaParent = dotOpLayout.getParent().dyn_cast<MfmaEncodingAttr>()) {
|
||||
if (elemTy.isF32())
|
||||
return elemTy;
|
||||
if (elemTy.isInteger(16)) // aka BF16
|
||||
return vec_ty(elemTy, dotOpLayout.getKWidth());
|
||||
if (elemTy.isF16())
|
||||
return vec_ty(elemTy, 4);
|
||||
if (elemTy.isInteger(8))
|
||||
return IntegerType::get(ctx, 32);
|
||||
}
|
||||
#endif
|
||||
|
||||
auto mmaParent = dotOpLayout.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
if (!mmaParent)
|
||||
return elemTy;
|
||||
int bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
assert(bitwidth <= 32);
|
||||
return IntegerType::get(ctx, 32);
|
||||
}
|
||||
|
||||
Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
|
||||
RankedTensorType type) {
|
||||
auto ctx = type.getContext();
|
||||
Attribute layout = type.getEncoding();
|
||||
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
|
||||
Type eltType = getElementTypeForStruct(type);
|
||||
|
||||
if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
SmallVector<Type, 4> types;
|
||||
// base ptr
|
||||
auto ptrType = LLVM::LLVMPointerType::get(eltType, 3);
|
||||
types.push_back(ptrType);
|
||||
// shape dims
|
||||
auto rank = type.getRank();
|
||||
// offsets + strides
|
||||
for (auto i = 0; i < rank * 2; i++) {
|
||||
types.push_back(IntegerType::get(ctx, 32));
|
||||
}
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
}
|
||||
|
||||
unsigned numElementsPerThread = getTotalElemsPerThread(type);
|
||||
SmallVector<Type, 4> types(numElementsPerThread, eltType);
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
}
|
||||
35
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TypeConverter.h
vendored
Normal file
35
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/TypeConverter.h
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_TYPECONVERTER_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_TYPECONVERTER_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
|
||||
public:
|
||||
using TypeConverter::convertType;
|
||||
|
||||
TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
|
||||
const DataLayoutAnalysis *analysis = nullptr);
|
||||
|
||||
Type getElementTypeForStruct(RankedTensorType type);
|
||||
Type convertTritonPointerType(triton::PointerType type);
|
||||
|
||||
Value packLLElements(Location loc, ValueRange resultVals,
|
||||
ConversionPatternRewriter &rewriter, Type type);
|
||||
|
||||
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type type);
|
||||
|
||||
Type convertTritonTensorType(RankedTensorType type);
|
||||
|
||||
SmallVector<Value> packMfmaOperand(
|
||||
const SmallVector<Value> &inValues, Type srcTy,
|
||||
ConversionPatternRewriter &rewriter, Location loc);
|
||||
};
|
||||
|
||||
#endif
|
||||
397
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/Utility.cpp
vendored
Normal file
397
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/Utility.cpp
vendored
Normal file
@@ -0,0 +1,397 @@
|
||||
#include "Utility.h"
|
||||
#include "TypeConverter.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace LLVM {
|
||||
using namespace mlir::triton;
|
||||
|
||||
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) {
|
||||
auto i32ty = rewriter.getIntegerType(32);
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
||||
IntegerAttr::get(i32ty, v));
|
||||
}
|
||||
|
||||
Value createConstantF32(Location loc, OpBuilder &rewriter, float v) {
|
||||
auto type = type::f32Ty(rewriter.getContext());
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, type,
|
||||
rewriter.getF32FloatAttr(v));
|
||||
}
|
||||
|
||||
Value createConstantF64(Location loc, OpBuilder &rewriter, float v) {
|
||||
auto type = type::f64Ty(rewriter.getContext());
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, type,
|
||||
rewriter.getF64FloatAttr(v));
|
||||
}
|
||||
|
||||
// Create an index type constant.
|
||||
Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value) {
|
||||
Type ty = converter->convertType(builder.getIndexType());
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
}
|
||||
|
||||
// Create an integer constant of \param width bits.
|
||||
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
int64_t value) {
|
||||
Type ty = builder.getIntegerType(width);
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
}
|
||||
|
||||
// A wrapper of LoadDSmemOp when vec = 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Create LoadDSmemOp
|
||||
// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy
|
||||
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
Value ret =
|
||||
rewriter.create<triton::nvgpu::LoadDSmemOp>(loc, addr, ctaId, bitwidth);
|
||||
return bitcast(ret, elemTy);
|
||||
}
|
||||
|
||||
// A wrapper of LoadDSmemOp when vec > 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Create LoadDSmemOp and extract results from retStruct
|
||||
// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy
|
||||
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
|
||||
Value addr, Value ctaId, unsigned vec) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
Value retStruct = rewriter.create<triton::nvgpu::LoadDSmemOp>(
|
||||
loc, addr, ctaId, bitwidth, vec);
|
||||
SmallVector<Value> retVals;
|
||||
for (unsigned i = 0; i < vec; ++i) {
|
||||
auto dataTy = rewriter.getIntegerType(bitwidth);
|
||||
Value data = extract_val(dataTy, retStruct, i);
|
||||
retVals.push_back(bitcast(data, elemTy));
|
||||
}
|
||||
return retVals;
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec = 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Bitcast value from elemTy to dataTy (u16/u32/u64)
|
||||
// (3) Create StoreDSmemOp
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value, Value pred) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
auto dataTy = rewriter.getIntegerType(bitwidth);
|
||||
Value data = bitcast(value, dataTy);
|
||||
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec = 1 and pred = 1
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value) {
|
||||
Value pred = int_val(/*width=*/1, 1);
|
||||
createStoreDSmem(loc, rewriter, addr, ctaId, value, pred);
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec > 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Bitcast values from elemTy to dataTy (u16/u32/u64)
|
||||
// (3) Create StoreDSmemOp
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values, Value pred) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
auto dataTy = rewriter.getIntegerType(bitwidth);
|
||||
SmallVector<Value> data;
|
||||
for (unsigned i = 0; i < values.size(); ++i)
|
||||
data.push_back(bitcast(values[i], dataTy));
|
||||
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec > 1 and pred = 1
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values) {
|
||||
Value pred = int_val(/*width=*/1, 1);
|
||||
createStoreDSmem(loc, rewriter, addr, ctaId, values, pred);
|
||||
}
|
||||
|
||||
SharedMemoryObject
|
||||
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
ArrayRef<Type> types =
|
||||
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
|
||||
SmallVector<Value> elems(types.size());
|
||||
for (unsigned i = 0; i < types.size(); ++i) {
|
||||
Type type = types[i];
|
||||
elems[i] = extract_val(type, llvmStruct, i);
|
||||
}
|
||||
|
||||
auto rank = (elems.size() - 1) / 2;
|
||||
return {/*base=*/elems[0],
|
||||
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
|
||||
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
|
||||
Location loc, ConversionPatternRewriter &rewriter) {
|
||||
auto rank = shape.size();
|
||||
SmallVector<Value> strides(rank);
|
||||
int64_t stride = 1;
|
||||
for (auto idx : order) {
|
||||
strides[idx] = i32_val(stride);
|
||||
stride *= shape[idx];
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
SmallVector<Value> reorderedMultiDim(rank);
|
||||
if (auto constantOp = linear.getDefiningOp<arith::ConstantOp>()) {
|
||||
unsigned intVal =
|
||||
constantOp.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered);
|
||||
} else {
|
||||
reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
|
||||
}
|
||||
SmallVector<Value> multiDim(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
multiDim[order[i]] = reorderedMultiDim[i];
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, unsigned linear,
|
||||
ArrayRef<unsigned> shape) {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank > 0);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
unsigned remained = linear;
|
||||
for (auto &&en : llvm::enumerate(shape)) {
|
||||
unsigned dimSize = en.value();
|
||||
multiDim[en.index()] = i32_val(remained % dimSize);
|
||||
remained = remained / dimSize;
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape) {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank > 0);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
Value remained = linear;
|
||||
for (auto &&en : llvm::enumerate(shape)) {
|
||||
Value dimSize = i32_val(en.value());
|
||||
multiDim[en.index()] = urem(remained, dimSize);
|
||||
remained = udiv(remained, dimSize);
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) {
|
||||
return linearize(rewriter, loc, reorder<Value>(multiDim, order),
|
||||
reorder<unsigned>(shape, order));
|
||||
}
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) {
|
||||
auto rank = multiDim.size();
|
||||
Value linear = i32_val(0);
|
||||
if (rank > 0) {
|
||||
linear = multiDim.back();
|
||||
for (auto [dim, dimShape] :
|
||||
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
|
||||
Value dimSize = i32_val(dimShape);
|
||||
linear = add(mul(linear, dimSize), dim);
|
||||
}
|
||||
}
|
||||
return linear;
|
||||
}
|
||||
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred) {
|
||||
#if 1
|
||||
store(val, ptr);
|
||||
return val;
|
||||
#else
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||
|
||||
PTXBuilder builder;
|
||||
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
|
||||
auto *valOpr = builder.newOperand(val, c);
|
||||
auto &st = builder.create<>("st")->shared().b(bits);
|
||||
st(ptrOpr, valOpr).predicate(pred, "b");
|
||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||
#endif
|
||||
}
|
||||
|
||||
static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value val, int i, const std::string &shuffleType,
|
||||
const std::string &clamp, Value laneId = Value()) {
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
|
||||
if (bits == 64) {
|
||||
Type vecTy = vec_ty(f32_ty, 2);
|
||||
Value vec = bitcast(val, vecTy);
|
||||
Value val0 = extract_element(f32_ty, vec, i32_val(0));
|
||||
Value val1 = extract_element(f32_ty, vec, i32_val(1));
|
||||
val0 = commonShflSync(loc, rewriter, val0, i, shuffleType, clamp, laneId);
|
||||
val1 = commonShflSync(loc, rewriter, val1, i, shuffleType, clamp, laneId);
|
||||
vec = undef(vecTy);
|
||||
vec = insert_element(vecTy, vec, val0, i32_val(0));
|
||||
vec = insert_element(vecTy, vec, val1, i32_val(1));
|
||||
return bitcast(vec, val.getType());
|
||||
}
|
||||
|
||||
#if 1
|
||||
GCNBuilder builder;
|
||||
if (shuffleType == "bfly") {
|
||||
if (i > 16) {
|
||||
Value threadId =
|
||||
rewriter
|
||||
.create<UnrealizedConversionCastOp>(
|
||||
loc, TypeRange{i32_ty},
|
||||
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
|
||||
.getResult(0);
|
||||
Value stride = i32_val(32);
|
||||
Value byteOffset = i32_val(2);
|
||||
Value lineId = add(threadId, stride);
|
||||
Value permuteAddr = shl(lineId, byteOffset);
|
||||
auto shfl = builder.create("ds_permute_b32");
|
||||
auto dOpr = builder.newOperand("=v");
|
||||
auto addrOpr = builder.newOperand(permuteAddr, "v");
|
||||
auto aOpr = builder.newOperand(val, "v");
|
||||
(*shfl)(dOpr, addrOpr, aOpr);
|
||||
} else {
|
||||
// This map facilates the butterfly shuffle pattern for a stride less
|
||||
// than 16. The pattern stride is the key of the map.
|
||||
DenseMap<short, unsigned int> masks{
|
||||
{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
|
||||
auto shfl = builder.create("ds_swizzle_b32");
|
||||
auto dOpr = builder.newOperand("=v");
|
||||
auto aOpr = builder.newOperand(val, "v");
|
||||
auto maskOpr =
|
||||
builder.newConstantOperand("offset:" + std::to_string(masks[i]));
|
||||
(*shfl)(dOpr, aOpr, maskOpr);
|
||||
}
|
||||
} else { // shuffle_up
|
||||
assert(shuffleType == "up" && "Only shfl_bfly and shfl_up are supported");
|
||||
Value mask = icmp_slt(laneId, i32_val(i));
|
||||
Value delta = sub(laneId, i32_val(i));
|
||||
Value index = select(mask, laneId, delta);
|
||||
Value byteOffset = i32_val(2);
|
||||
Value permuteAddr = shl(index, byteOffset);
|
||||
auto shfl = builder.create("ds_bpermute_b32");
|
||||
auto dOpr = builder.newOperand("=v");
|
||||
auto addrOpr = builder.newOperand(permuteAddr, "v");
|
||||
auto aOpr = builder.newOperand(val, "v");
|
||||
(*shfl)(dOpr, addrOpr, aOpr);
|
||||
}
|
||||
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
|
||||
(*swait)();
|
||||
return builder.launch(rewriter, loc, val.getType(), true);
|
||||
#else
|
||||
PTXBuilder builder;
|
||||
auto &shfl = builder.create("shfl.sync")->o(shuffleType).o("b32");
|
||||
auto *dOpr = builder.newOperand("=r");
|
||||
auto *aOpr = builder.newOperand(val, "r");
|
||||
auto *bOpr = builder.newConstantOperand(i);
|
||||
auto *cOpr = builder.newConstantOperand(clamp);
|
||||
auto *maskOpr = builder.newConstantOperand("0xffffffff");
|
||||
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
|
||||
return builder.launch(rewriter, loc, val.getType(), false);
|
||||
#endif
|
||||
}
|
||||
|
||||
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
int i) {
|
||||
return commonShflSync(loc, rewriter, val, i, "bfly", "0x1f");
|
||||
}
|
||||
|
||||
Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
int i, Value laneId) {
|
||||
return commonShflSync(loc, rewriter, val, i, "up", "0x0", laneId);
|
||||
}
|
||||
Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) {
|
||||
PTXBuilder builder;
|
||||
auto &mov = builder.create("mov")->o("u32");
|
||||
auto *destOpr = builder.newOperand("=r");
|
||||
auto *sRegOpr = builder.newConstantOperand(sRegStr);
|
||||
mov(destOpr, sRegOpr);
|
||||
Value val = builder.launch(b, loc, b.getIntegerType(32), false);
|
||||
return val;
|
||||
}
|
||||
|
||||
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
|
||||
StringRef key, StringRef content) {
|
||||
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
auto ctx = moduleOp.getContext();
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(key + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
|
||||
llvm::SmallString<64> contentStr(content);
|
||||
size_t contentSize = contentStr.size_in_bytes();
|
||||
auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize);
|
||||
|
||||
LLVM::GlobalOp global;
|
||||
{
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
global = rewriter.create<LLVM::GlobalOp>(
|
||||
UnknownLoc::get(ctx), globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
||||
rewriter.getStringAttr(contentStr));
|
||||
}
|
||||
|
||||
Value zero = i32_val(0);
|
||||
Value globalPtr =
|
||||
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), global);
|
||||
Value stringStart =
|
||||
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
|
||||
globalPtr, SmallVector<Value>({zero, zero}));
|
||||
return stringStart;
|
||||
}
|
||||
|
||||
} // namespace LLVM
|
||||
|
||||
bool isF8(Type eType) {
|
||||
return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or
|
||||
eType.isFloat8E5M2() or eType.isFloat8E5M2FNUZ();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
346
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/Utility.h
vendored
Normal file
346
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/Utility.h
vendored
Normal file
@@ -0,0 +1,346 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_UTILITY_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_UTILITY_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/AnalysisROCM/Utility.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Conversion/TritonGPUROCMToLLVM/GCNAsmFormat.h"
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
// Operators
|
||||
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
|
||||
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
|
||||
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
||||
#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
|
||||
#define sext(...) rewriter.create<LLVM::SExtOp>(loc, __VA_ARGS__)
|
||||
#define fpext(...) rewriter.create<LLVM::FPExtOp>(loc, __VA_ARGS__)
|
||||
#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
|
||||
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
|
||||
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
|
||||
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
|
||||
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
|
||||
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
|
||||
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
|
||||
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
|
||||
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
|
||||
#define lshr(...) rewriter.create<LLVM::LShrOp>(loc, __VA_ARGS__)
|
||||
#define ashr(...) rewriter.create<LLVM::AShrOp>(loc, __VA_ARGS__)
|
||||
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
|
||||
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
|
||||
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
|
||||
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
|
||||
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
|
||||
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
|
||||
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
|
||||
#define lshr(...) rewriter.create<LLVM::LShrOp>(loc, __VA_ARGS__)
|
||||
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
|
||||
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
|
||||
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
|
||||
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
|
||||
#define bitcast(val__, type__) \
|
||||
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
|
||||
#define addrspacecast(val__, type__) \
|
||||
rewriter.create<LLVM::AddrSpaceCastOp>(loc, type__, val__)
|
||||
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
|
||||
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
|
||||
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
|
||||
#define extract_val(...) rewriter.create<LLVM::ExtractValueOp>(loc, __VA_ARGS__)
|
||||
#define insert_element(...) \
|
||||
rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
|
||||
#define extract_element(...) \
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
|
||||
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
|
||||
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
|
||||
#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__)
|
||||
#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__)
|
||||
#define fcmp_ogt(lhs, rhs) \
|
||||
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
|
||||
LLVM::FCmpPredicate::ogt, lhs, rhs)
|
||||
#define fcmp_olt(lhs, rhs) \
|
||||
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
|
||||
LLVM::FCmpPredicate::olt, lhs, rhs)
|
||||
#define fcmp_eq(lhs, rhs) \
|
||||
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
|
||||
LLVM::FCmpPredicate::oeq, lhs, rhs)
|
||||
#define icmp_eq(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
|
||||
#define icmp_ne(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
|
||||
#define icmp_slt(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
|
||||
#define icmp_sle(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
|
||||
#define icmp_sgt(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
|
||||
#define icmp_sge(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
|
||||
#define icmp_ult(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
|
||||
#define icmp_ule(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
|
||||
#define icmp_ugt(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
|
||||
#define icmp_uge(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
|
||||
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
|
||||
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
|
||||
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
|
||||
#define barSync(rewriter, op, bar, numThreads) \
|
||||
do { \
|
||||
::mlir::triton::PTXBuilder ptxBuilder; \
|
||||
auto &barSyncOp = *ptxBuilder.create<>("bar.sync"); \
|
||||
barSyncOp(ptxBuilder.newConstantOperand(bar), \
|
||||
ptxBuilder.newConstantOperand(numThreads)); \
|
||||
auto voidTy = void_ty(op->getContext()); \
|
||||
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
|
||||
} while (0)
|
||||
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
|
||||
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
|
||||
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
|
||||
|
||||
// Types
|
||||
#define int_ty(width) rewriter.getIntegerType(width)
|
||||
#define i64_ty rewriter.getIntegerType(64)
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define i16_ty rewriter.getIntegerType(16)
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define i64_ty rewriter.getIntegerType(64)
|
||||
#define ui32_ty rewriter.getIntegerType(32, false)
|
||||
#define f16_ty rewriter.getF16Type()
|
||||
#define bf16_ty rewriter.getBF16Type()
|
||||
#define i8_ty rewriter.getIntegerType(8)
|
||||
#define i1_ty rewriter.getI1Type()
|
||||
#define f32_ty rewriter.getF32Type()
|
||||
#define f64_ty rewriter.getF64Type()
|
||||
#define vec_ty(type, num) VectorType::get(num, type)
|
||||
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
|
||||
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
|
||||
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
|
||||
|
||||
// Constants
|
||||
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
|
||||
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
|
||||
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
|
||||
#define int_val(width, val) \
|
||||
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
|
||||
#define tid_val() getThreadId(rewriter, loc)
|
||||
|
||||
// Attributes
|
||||
#define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__})
|
||||
#define i64_arr_attr(...) rewriter.getI64ArrayAttr({__VA_ARGS__})
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Delinearize supposing order is [0, 1, .. , n]
|
||||
template <typename T>
|
||||
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
|
||||
llvm::ArrayRef<T> shape) {
|
||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||
size_t rank = shape.size();
|
||||
T accMul = product(shape.drop_back());
|
||||
T linearRemain = linearIndex;
|
||||
llvm::SmallVector<T> multiDimIndex(rank);
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
multiDimIndex[i] = linearRemain / accMul;
|
||||
linearRemain = linearRemain % accMul;
|
||||
if (i != 0) {
|
||||
accMul = accMul / shape[i - 1];
|
||||
}
|
||||
}
|
||||
return multiDimIndex;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::SmallVector<T> getMultiDimIndex(T linearIndex, llvm::ArrayRef<T> shape,
|
||||
llvm::ArrayRef<unsigned> order) {
|
||||
size_t rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
|
||||
llvm::SmallVector<T> multiDim(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
multiDim[order[i]] = reorderedMultiDim[i];
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
// Linearize supposing order is [0, 1, .. , n]
|
||||
template <typename T>
|
||||
T getLinearIndexImpl(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape) {
|
||||
assert(multiDimIndex.size() == shape.size());
|
||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||
size_t rank = shape.size();
|
||||
T accMul = product(shape.drop_back());
|
||||
T linearIndex = 0;
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
linearIndex += multiDimIndex[i] * accMul;
|
||||
if (i != 0) {
|
||||
accMul = accMul / shape[i - 1];
|
||||
}
|
||||
}
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
|
||||
llvm::ArrayRef<unsigned> order) {
|
||||
assert(shape.size() == order.size());
|
||||
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
|
||||
reorder(shape, order));
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
|
||||
namespace LLVM {
|
||||
using namespace mlir::triton;
|
||||
|
||||
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);
|
||||
|
||||
/// Create a 32-bit float constant.
|
||||
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);
|
||||
|
||||
/// Create a 64-bit float constant.
|
||||
Value createConstantF64(Location loc, OpBuilder &rewriter, float v);
|
||||
|
||||
/// Create an index type constant.
|
||||
Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value);
|
||||
|
||||
/// Create an integer constant of \param width bits.
|
||||
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
int64_t value);
|
||||
|
||||
/// Usage of macro load_dsmem
|
||||
/// (1) load_dsmem(addr, ctaId)
|
||||
/// (2) load_dsmem(addr, ctaId, vec)
|
||||
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId);
|
||||
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
|
||||
Value addr, Value ctaId, unsigned vec);
|
||||
|
||||
/// Usage of macro store_dsmem
|
||||
/// (1) store_dsmem(addr, ctaId, value, pred)
|
||||
/// (2) store_dsmem(addr, ctaId, value)
|
||||
/// (3) store_dsmem(addr, ctaId, values, pred)
|
||||
/// (4) store_dsmem(addr, ctaId, values)
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value, Value pred);
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value);
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values, Value pred);
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values);
|
||||
|
||||
/// Helper function to get strides from a given shape and its order
|
||||
SmallVector<Value>
|
||||
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
|
||||
Location loc, ConversionPatternRewriter &rewriter);
|
||||
struct SharedMemoryObject {
|
||||
Value base; // i32 ptr. The start address of the shared memory object.
|
||||
// We need to store strides as Values but not integers because the
|
||||
// extract_slice instruction can take a slice at arbitrary offsets.
|
||||
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
|
||||
// 32, we need to let the instruction that uses $a to be aware of that.
|
||||
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
|
||||
// we store strides into an attribute array of integers, the information
|
||||
// cannot pass through block argument assignment because attributes are
|
||||
// associated with operations but not Values.
|
||||
// TODO(Keren): We may need to figure out a way to store strides as integers
|
||||
// if we want to support more optimizations.
|
||||
SmallVector<Value>
|
||||
strides; // i32 int. The strides of the shared memory object.
|
||||
SmallVector<Value> offsets; // i32 int. The offsets of the shared memory
|
||||
// objects from the originally allocated object.
|
||||
|
||||
SharedMemoryObject(Value base, ArrayRef<Value> strides,
|
||||
ArrayRef<Value> offsets)
|
||||
: base(base), strides(strides.begin(), strides.end()),
|
||||
offsets(offsets.begin(), offsets.end()) {}
|
||||
|
||||
SharedMemoryObject(Value base, ArrayRef<int64_t> shape,
|
||||
ArrayRef<unsigned> order, Location loc,
|
||||
ConversionPatternRewriter &rewriter)
|
||||
: base(base) {
|
||||
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
|
||||
offsets.append(order.size(), i32_val(0));
|
||||
}
|
||||
|
||||
SmallVector<Value> getElems() const {
|
||||
SmallVector<Value> elems;
|
||||
elems.push_back(base);
|
||||
elems.append(strides.begin(), strides.end());
|
||||
elems.append(offsets.begin(), offsets.end());
|
||||
return elems;
|
||||
}
|
||||
|
||||
SmallVector<Type> getTypes() const {
|
||||
SmallVector<Type> types;
|
||||
types.push_back(base.getType());
|
||||
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
|
||||
types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
|
||||
return types;
|
||||
}
|
||||
|
||||
Value getCSwizzleOffset(int order) const {
|
||||
assert(order >= 0 && order < strides.size());
|
||||
return offsets[order];
|
||||
}
|
||||
|
||||
Value getBaseBeforeSlice(int order, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value cSwizzleOffset = getCSwizzleOffset(order);
|
||||
Value offset = sub(i32_val(0), cSwizzleOffset);
|
||||
Type type = base.getType();
|
||||
return gep(type, base, offset);
|
||||
}
|
||||
};
|
||||
|
||||
SharedMemoryObject
|
||||
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order);
|
||||
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, unsigned linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order);
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape);
|
||||
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred);
|
||||
|
||||
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
int i);
|
||||
Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
int i, Value laneId);
|
||||
|
||||
Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr);
|
||||
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
|
||||
StringRef key, StringRef content);
|
||||
|
||||
} // namespace LLVM
|
||||
|
||||
bool isF8(Type eType);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
237
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ViewOpToLLVM.cpp
vendored
Normal file
237
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ViewOpToLLVM.cpp
vendored
Normal file
@@ -0,0 +1,237 @@
|
||||
#include "ViewOpToLLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu_rocm::getTotalElemsPerThread;
|
||||
|
||||
struct SplatOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
|
||||
// LLVM::StructType value.
|
||||
//
|
||||
// @elemType: the element type in operand.
|
||||
// @resType: the return type of the Splat-like op.
|
||||
// @constVal: a LLVM::ConstantOp or other scalar value.
|
||||
static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
// Check the converted type for the tensor as depending on the encoding the
|
||||
// converter may pick different element types.
|
||||
auto srcType = typeConverter->convertType(tensorTy);
|
||||
if (auto structTy = dyn_cast<LLVM::LLVMStructType>(srcType))
|
||||
srcType = structTy.getBody()[0];
|
||||
// If the type sizes don't match we need to pack constants.
|
||||
if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() !=
|
||||
srcType.getIntOrFloatBitWidth()) {
|
||||
unsigned cstBitWidth = constVal.getType().getIntOrFloatBitWidth();
|
||||
unsigned srcBitWidth = srcType.getIntOrFloatBitWidth();
|
||||
assert(cstBitWidth <= srcBitWidth && srcBitWidth % cstBitWidth == 0);
|
||||
unsigned ratio = srcBitWidth / cstBitWidth;
|
||||
Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth);
|
||||
VectorType vecType = VectorType::get(ratio, intTy);
|
||||
Value intCst = bitcast(constVal, intTy);
|
||||
Value vec = undef(vecType);
|
||||
for (unsigned i = 0; i < ratio; ++i)
|
||||
vec = insert_element(vecType, vec, intCst, int_val(32, i));
|
||||
constVal = vec;
|
||||
}
|
||||
auto llSrc = bitcast(constVal, srcType);
|
||||
size_t elemsPerThread = getTotalElemsPerThread(tensorTy);
|
||||
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
|
||||
return typeConverter->packLLElements(loc, elems, rewriter, resType);
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op->getLoc();
|
||||
auto src = adaptor.getSrc();
|
||||
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
|
||||
getTypeConverter(), rewriter, loc);
|
||||
rewriter.replaceOp(op, {llStruct});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
|
||||
// the logic is the same as triton::SplatOp, so the underlying implementation
|
||||
// is reused.
|
||||
struct ArithConstantSplatOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto value = op.getValue();
|
||||
if (!value.dyn_cast<SplatElementsAttr>())
|
||||
return failure();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
|
||||
LLVM::ConstantOp arithConstantOp;
|
||||
auto values = op.getValue().dyn_cast<SplatElementsAttr>();
|
||||
auto elemType = values.getElementType();
|
||||
|
||||
Attribute val;
|
||||
if (elemType.isBF16() || type::isFloat(elemType)) {
|
||||
val = values.getValues<FloatAttr>()[0];
|
||||
} else if (type::isInt(elemType)) {
|
||||
val = values.getValues<IntegerAttr>()[0];
|
||||
} else {
|
||||
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
|
||||
<< value.getType() << "\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
|
||||
auto llStruct = SplatOpConversion::convertSplatLikeOp(
|
||||
elemType, op.getType(), constOp, getTypeConverter(), rewriter, loc);
|
||||
rewriter.replaceOp(op, llStruct);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
using OpAdaptor = typename CatOp::Adaptor;
|
||||
|
||||
explicit CatOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
unsigned elems = getTotalElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
// unpack input values
|
||||
auto lhsVals = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getLhs(), rewriter, op.getOperand(0).getType());
|
||||
auto rhsVals = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getRhs(), rewriter, op.getOperand(1).getType());
|
||||
// concatenate (and potentially reorder) values
|
||||
SmallVector<Value> retVals;
|
||||
for (Value v : lhsVals)
|
||||
retVals.push_back(v);
|
||||
for (Value v : rhsVals)
|
||||
retVals.push_back(v);
|
||||
// pack and replace
|
||||
Value ret =
|
||||
getTypeConverter()->packLLElements(loc, retVals, rewriter, resultTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
|
||||
using OpAdaptor = typename ViewOp::Adaptor;
|
||||
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
auto vals = this->getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
|
||||
Value ret =
|
||||
this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpandDimsOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp> {
|
||||
using OpAdaptor = typename ExpandDimsOp::Adaptor;
|
||||
explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto srcVals = this->getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
|
||||
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
|
||||
assert(srcTy.getEncoding().isa<SliceEncodingAttr>() &&
|
||||
"ExpandDimsOp only support SliceEncodingAttr");
|
||||
auto srcLayout = srcTy.getEncoding().dyn_cast<SliceEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding();
|
||||
|
||||
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
|
||||
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
|
||||
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
|
||||
for (size_t i = 0; i < srcOffsets.size(); i++) {
|
||||
srcValues[srcOffsets[i]] = srcVals[i];
|
||||
}
|
||||
|
||||
SmallVector<Value> resultVals;
|
||||
for (size_t i = 0; i < resultOffsets.size(); i++) {
|
||||
auto offset = resultOffsets[i];
|
||||
offset.erase(offset.begin() + srcLayout.getDim());
|
||||
resultVals.push_back(srcValues.lookup(offset));
|
||||
}
|
||||
Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TransOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto srcSmemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
|
||||
SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
|
||||
srcSmemObj.strides[0]};
|
||||
SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
|
||||
srcSmemObj.offsets[0]};
|
||||
auto dstSmemObj =
|
||||
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<ViewOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
|
||||
patterns.add<SplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<TransOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
15
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ViewOpToLLVM.h
vendored
Normal file
15
python/triton/third_party/hip/lib/Conversion/TritonGPUROCMToLLVM/ViewOpToLLVM.h
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_VIEW_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPUROCM_TO_LLVM_VIEW_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
21
python/triton/third_party/hip/lib/Conversion/TritonToTritonGPUROCM/CMakeLists.txt
vendored
Normal file
21
python/triton/third_party/hip/lib/Conversion/TritonToTritonGPUROCM/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
add_mlir_conversion_library(TritonToTritonGPUROCM
|
||||
TritonToTritonGPUPass.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPUROCM
|
||||
${PROJECT_BINARY_DIR}/include/triton/Conversion/TritonToTritonGPUROCM
|
||||
|
||||
DEPENDS
|
||||
TritonToTritonGPUROCMConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
TritonIR
|
||||
TritonGPUROCMIR
|
||||
TritonGPUROCMTransforms
|
||||
)
|
||||
977
python/triton/third_party/hip/lib/Conversion/TritonToTritonGPUROCM/TritonToTritonGPUPass.cpp
vendored
Normal file
977
python/triton/third_party/hip/lib/Conversion/TritonToTritonGPUROCM/TritonToTritonGPUPass.cpp
vendored
Normal file
@@ -0,0 +1,977 @@
|
||||
#include "triton/Conversion/TritonToTritonGPUROCM/TritonToTritonGPUPass.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/AnalysisROCM/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPUROCM/Transforms/TritonGPUConversion.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/TritonToTritonGPUROCM/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
// pass named attrs (e.g., tt.contiguity) from Triton to Triton
|
||||
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
|
||||
for (const NamedAttribute attr : dictAttrs.getValue())
|
||||
if (!op->hasAttr(attr.getName()))
|
||||
op->setAttr(attr.getName(), attr.getValue());
|
||||
}
|
||||
|
||||
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <class SrcOp, class DstOp>
|
||||
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
||||
public:
|
||||
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
auto retShapedType = retType.cast<ShapedType>();
|
||||
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||
if (dyn_cast<RankedTensorType>(retShapedType)) {
|
||||
assert(value);
|
||||
if (value.getElementType().isInteger(1) && value.isSplat())
|
||||
// Workaround until https://reviews.llvm.org/D133743 is included.
|
||||
value =
|
||||
DenseElementsAttr::get(retShapedType, value.getSplatValue<bool>());
|
||||
else
|
||||
// This is a hack. We just want to add encoding
|
||||
value = value.reshape(retShapedType);
|
||||
}
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retShapedType, value),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertArithOp : public ConversionPattern {
|
||||
public:
|
||||
ConvertArithOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Dialect *dialect = op->getDialect();
|
||||
if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithDialect>())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
// --------------
|
||||
// Add legality and rewrite pattern rules for operations
|
||||
// from the Arith dialect. The basic premise is that
|
||||
// Arith operations require both inputs to have the same
|
||||
// non-null encoding
|
||||
// --------------
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// TODO: there's probably a better way to avoid adding all ops one-by-one
|
||||
patterns.add<
|
||||
ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
|
||||
GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
|
||||
GenericOpPattern<arith::DivUIOp>, GenericOpPattern<arith::DivSIOp>,
|
||||
GenericOpPattern<arith::CeilDivUIOp>,
|
||||
GenericOpPattern<arith::CeilDivSIOp>,
|
||||
GenericOpPattern<arith::FloorDivSIOp>, GenericOpPattern<arith::RemUIOp>,
|
||||
GenericOpPattern<arith::RemSIOp>, GenericOpPattern<arith::AndIOp>,
|
||||
GenericOpPattern<arith::OrIOp>, GenericOpPattern<arith::XOrIOp>,
|
||||
GenericOpPattern<arith::ShLIOp>, GenericOpPattern<arith::ShRUIOp>,
|
||||
GenericOpPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
|
||||
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
|
||||
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
|
||||
GenericOpPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu_rocm::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu_rocm::CmpFOp>,
|
||||
// Cast Ops
|
||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
||||
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
|
||||
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
|
||||
GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
|
||||
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
// this shouldn't exist if mlir's SelectOp checked encodings properly
|
||||
class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
|
||||
Value cond = adaptor.getCondition();
|
||||
if (llvm::isa<RankedTensorType>(retType) &&
|
||||
!llvm::isa<TensorType>(cond.getType())) {
|
||||
// triton_gpu_rocm.select doesn't support scalar condition values, so add a
|
||||
// splat
|
||||
auto retTypeTensor = llvm::cast<RankedTensorType>(retType);
|
||||
auto retShape = retTypeTensor.getShape();
|
||||
auto retEncoding = retTypeTensor.getEncoding();
|
||||
Type condTy =
|
||||
RankedTensorType::get(retShape, cond.getType(), retEncoding);
|
||||
cond = rewriter.create<triton::SplatOp>(op.getLoc(), condTy, cond);
|
||||
}
|
||||
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<triton::gpu_rocm::SelectOp>(
|
||||
op, retType, cond, adaptor.getTrueValue(), adaptor.getFalseValue()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Rewrite rule
|
||||
patterns.add<StdSelectPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Rewrite rule
|
||||
patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
|
||||
GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
|
||||
GenericOpPattern<math::AbsFOp>, GenericOpPattern<math::AbsIOp>,
|
||||
GenericOpPattern<math::SqrtOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
// Triton patterns
|
||||
//
|
||||
// TODO: Do we need to put them in anonymous namespace?
|
||||
struct TritonMakeRangePattern
|
||||
: public OpConversionPattern<triton::MakeRangeOp> {
|
||||
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op, retType, adaptor.getStart(), adaptor.getEnd()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonExpandDimsPattern
|
||||
: public OpConversionPattern<triton::ExpandDimsOp> {
|
||||
using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Type retType = op.getType());
|
||||
RankedTensorType argType =
|
||||
adaptor.getSrc().getType().cast<RankedTensorType>();
|
||||
Attribute _argEncoding = argType.getEncoding();
|
||||
if (!_argEncoding)
|
||||
return failure();
|
||||
auto argEncoding = _argEncoding.cast<triton::gpu_rocm::BlockedEncodingAttr>();
|
||||
// return shape
|
||||
auto retShape = argType.getShape().vec();
|
||||
retShape.insert(retShape.begin() + op.getAxis(), 1);
|
||||
// return encoding
|
||||
auto retSizePerThread = argEncoding.getSizePerThread().vec();
|
||||
retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1);
|
||||
auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec();
|
||||
retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1);
|
||||
auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec();
|
||||
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
|
||||
SmallVector<unsigned, 4> retOrder(retShape.size());
|
||||
std::iota(retOrder.begin(), retOrder.end(), 0);
|
||||
|
||||
auto argCTALayout = argEncoding.getCTALayout();
|
||||
auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis());
|
||||
auto retCTASplitNum =
|
||||
insertOne(argCTALayout.getCTASplitNum(), op.getAxis());
|
||||
auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis());
|
||||
auto retCTALayout = triton::gpu_rocm::CTALayoutAttr::get(
|
||||
getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder);
|
||||
|
||||
triton::gpu_rocm::BlockedEncodingAttr retEncoding =
|
||||
triton::gpu_rocm::BlockedEncodingAttr::get(getContext(), retSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder, retCTALayout);
|
||||
// convert operand to slice of return type
|
||||
Attribute newArgEncoding = triton::gpu_rocm::SliceEncodingAttr::get(
|
||||
getContext(), op.getAxis(), retEncoding);
|
||||
RankedTensorType newArgType = RankedTensorType::get(
|
||||
argType.getShape(), argType.getElementType(), newArgEncoding);
|
||||
// construct new op
|
||||
auto newSrc = rewriter.create<triton::gpu_rocm::ConvertLayoutOp>(
|
||||
op.getLoc(), newArgType, adaptor.getSrc());
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
|
||||
op, newSrc, adaptor.getAxis()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
SmallVector<T> insertOne(ArrayRef<T> vec, unsigned axis) const {
|
||||
SmallVector<T> res(vec.begin(), vec.end());
|
||||
res.insert(res.begin() + axis, 1);
|
||||
return res;
|
||||
}
|
||||
|
||||
// Example: order = [ 0, 2, 1, 3], dim = 2
|
||||
// resOrder = [2, 0, 3, 1, 4]
|
||||
SmallVector<unsigned> insertOrder(ArrayRef<unsigned> order,
|
||||
unsigned axis) const {
|
||||
SmallVector<unsigned> resOrder(order.begin(), order.end());
|
||||
for (unsigned i = 0; i < resOrder.size(); ++i)
|
||||
if (resOrder[i] >= axis)
|
||||
++resOrder[i];
|
||||
resOrder.insert(resOrder.begin(), axis);
|
||||
return resOrder;
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RankedTensorType origType = op.getType().cast<RankedTensorType>();
|
||||
auto origShape = origType.getShape();
|
||||
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
|
||||
int numWarps = typeConverter->getNumWarps();
|
||||
int threadsPerWarp = typeConverter->getThreadsPerWarp();
|
||||
int numCTAs = typeConverter->getNumCTAs();
|
||||
|
||||
SmallVector<unsigned> retSizePerThread = {1, 1};
|
||||
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 4)
|
||||
retSizePerThread = {2, 2};
|
||||
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 16)
|
||||
retSizePerThread = {4, 4};
|
||||
SmallVector<unsigned> retOrder = {1, 0};
|
||||
Attribute dEncoding = triton::gpu_rocm::BlockedEncodingAttr::get(
|
||||
getContext(), origShape, retSizePerThread, retOrder, numWarps,
|
||||
threadsPerWarp, numCTAs);
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
|
||||
// a & b must be of smem layout
|
||||
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
|
||||
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
|
||||
Type aEltType = aType.getElementType();
|
||||
Type bEltType = bType.getElementType();
|
||||
Attribute aEncoding = aType.getEncoding();
|
||||
Attribute bEncoding = bType.getEncoding();
|
||||
if (!aEncoding || !bEncoding)
|
||||
return failure();
|
||||
Value a = adaptor.getA();
|
||||
Value b = adaptor.getB();
|
||||
Value c = adaptor.getC();
|
||||
if (!aEncoding.isa<triton::gpu_rocm::DotOperandEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu_rocm::DotOperandEncodingAttr::get(
|
||||
getContext(), 0, dEncoding, aEltType);
|
||||
auto dstType =
|
||||
RankedTensorType::get(aType.getShape(), aEltType, encoding);
|
||||
a = rewriter.create<triton::gpu_rocm::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu_rocm::DotOperandEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu_rocm::DotOperandEncodingAttr::get(
|
||||
getContext(), 1, dEncoding, bEltType);
|
||||
auto dstType =
|
||||
RankedTensorType::get(bType.getShape(), bEltType, encoding);
|
||||
b = rewriter.create<triton::gpu_rocm::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
c = rewriter.create<triton::gpu_rocm::ConvertLayoutOp>(c.getLoc(), retType, c);
|
||||
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, c, adaptor.getAllowTF32()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
|
||||
|
||||
using OpConversionPattern<triton::CatOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// The cat op satisfy two conditions:
|
||||
// 1. output.numel = lhs.numel + rhs.numel
|
||||
// 2. output.total_elems_per_thread =
|
||||
// next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread)
|
||||
// For now, this behaves like generic, but this
|
||||
// will evolve when we add support for `can_reorder=False`.
|
||||
auto retType = this->getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto retEncoding =
|
||||
retType.getEncoding().cast<triton::gpu_rocm::BlockedEncodingAttr>();
|
||||
auto lhsType = adaptor.getLhs().getType().cast<RankedTensorType>();
|
||||
auto rhsType = adaptor.getRhs().getType().cast<RankedTensorType>();
|
||||
auto lhsTotalElemsPerThread = triton::gpu_rocm::getTotalElemsPerThread(lhsType);
|
||||
auto rhsTotalElemsPerThread = triton::gpu_rocm::getTotalElemsPerThread(rhsType);
|
||||
auto retTotalElemsPerThread = triton::gpu_rocm::getTotalElemsPerThread(retType);
|
||||
auto retShape = retType.getShape();
|
||||
auto retOrder = retEncoding.getOrder();
|
||||
auto retSizePerThread = retEncoding.getSizePerThread();
|
||||
auto retThreadsPerWarp = retEncoding.getThreadsPerWarp();
|
||||
auto retWarpsPerCTA = retEncoding.getWarpsPerCTA();
|
||||
// Get new retSizePerThread if ret elems per thread is not enough.
|
||||
// We have to round it up to the next power of 2 due to triton's tensor size
|
||||
// constraint.
|
||||
auto newRetTotalElemsPerThread =
|
||||
nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread);
|
||||
auto newRetSizePerThread = retSizePerThread.vec();
|
||||
newRetSizePerThread[retOrder[0]] *=
|
||||
newRetTotalElemsPerThread / retTotalElemsPerThread;
|
||||
triton::gpu_rocm::BlockedEncodingAttr newRetEncoding =
|
||||
triton::gpu_rocm::BlockedEncodingAttr::get(
|
||||
getContext(), newRetSizePerThread, retThreadsPerWarp,
|
||||
retWarpsPerCTA, retOrder, retEncoding.getCTALayout());
|
||||
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
|
||||
newRetEncoding);
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
|
||||
op, newRetType, adaptor.getOperands()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
|
||||
|
||||
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value src = adaptor.getSrc();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
Attribute srcEncoding = srcType.getEncoding();
|
||||
if (!srcEncoding)
|
||||
return failure();
|
||||
if (!srcEncoding.isa<triton::gpu_rocm::SharedEncodingAttr>()) {
|
||||
// TODO: end-to-end correctness is broken if
|
||||
// the input is blocked and the output is shared
|
||||
// with different order. Maybe a backend issue in BlockedToShared?
|
||||
SmallVector<unsigned> order = {1, 0};
|
||||
if (auto srcBlockedEncoding =
|
||||
srcEncoding.dyn_cast<triton::gpu_rocm::BlockedEncodingAttr>())
|
||||
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
|
||||
// TODO(Qingyi): need to check whether the CTALayout of srcEncoding should
|
||||
// be used here. For tests where numCTAs = 1, this is not a problem since
|
||||
// all CTALayouts are the same.
|
||||
auto CTALayout = triton::gpu_rocm::getCTALayout(srcEncoding);
|
||||
srcEncoding = triton::gpu_rocm::SharedEncodingAttr::get(getContext(), 1, 1, 1,
|
||||
order, CTALayout);
|
||||
srcType = RankedTensorType::get(srcType.getShape(),
|
||||
srcType.getElementType(), srcEncoding);
|
||||
src = rewriter.create<triton::gpu_rocm::ConvertLayoutOp>(src.getLoc(), srcType,
|
||||
src);
|
||||
}
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::TransOp>(op, src),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(),
|
||||
adaptor.getBoundaryCheckAttr(), adaptor.getPaddingAttr(),
|
||||
adaptor.getCache(), adaptor.getEvict(),
|
||||
adaptor.getIsVolatile()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
op, adaptor.getPtr(), adaptor.getValue(),
|
||||
adaptor.getMask(), adaptor.getCache(),
|
||||
adaptor.getEvict()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonAtomicCASPattern
|
||||
: public OpConversionPattern<triton::AtomicCASOp> {
|
||||
using OpConversionPattern<triton::AtomicCASOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal(),
|
||||
op.getSem()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonAtomicRMWPattern
|
||||
: public OpConversionPattern<triton::AtomicRMWOp> {
|
||||
using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getAtomicRmwOp(), adaptor.getPtr(),
|
||||
adaptor.getVal(), adaptor.getMask(), op.getSem()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <class Op>
|
||||
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonBroadcastPattern
|
||||
: public OpConversionPattern<triton::BroadcastOp> {
|
||||
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
|
||||
|
||||
// This creates a tensor with the new shape but the argument's layout
|
||||
LogicalResult
|
||||
matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = adaptor.getSrc().getType().cast<RankedTensorType>();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
if (!srcEncoding)
|
||||
return failure();
|
||||
auto opType = op.getType().cast<RankedTensorType>();
|
||||
Type retType = RankedTensorType::get(opType.getShape(),
|
||||
opType.getElementType(), srcEncoding);
|
||||
// Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
||||
op, retType, adaptor.getOperands()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newReduce = rewriter.create<triton::ReduceOp>(
|
||||
op.getLoc(), adaptor.getOperands(), adaptor.getAxis());
|
||||
addNamedAttrs(newReduce, adaptor.getAttributes());
|
||||
|
||||
auto &newCombineOp = newReduce.getCombineOp();
|
||||
rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
|
||||
newCombineOp.end());
|
||||
rewriter.replaceOp(op, newReduce.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonReduceReturnPattern
|
||||
: public OpConversionPattern<triton::ReduceReturnOp> {
|
||||
using OpConversionPattern<triton::ReduceReturnOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ReduceReturnOp>(
|
||||
op, adaptor.getResult()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonScanPattern : public OpConversionPattern<triton::ScanOp> {
|
||||
using OpConversionPattern<triton::ScanOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newScan = rewriter.create<triton::ScanOp>(
|
||||
op.getLoc(), adaptor.getOperands(), adaptor.getAxis());
|
||||
addNamedAttrs(newScan, adaptor.getAttributes());
|
||||
|
||||
auto &newCombineOp = newScan.getCombineOp();
|
||||
rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
|
||||
newCombineOp.end());
|
||||
rewriter.replaceOp(op, newScan.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonScanReturnPattern
|
||||
: public OpConversionPattern<triton::ScanReturnOp> {
|
||||
using OpConversionPattern<triton::ScanReturnOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ScanReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ScanReturnOp>(
|
||||
op, adaptor.getResult()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonPrintPattern : public OpConversionPattern<triton::PrintOp> {
|
||||
using OpConversionPattern<triton::PrintOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::PrintOp op, typename triton::PrintOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::PrintOp>(
|
||||
op, op.getPrefixAttr(), adaptor.getOperands()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
|
||||
using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AssertOp op,
|
||||
typename triton::AssertOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AssertOp>(
|
||||
op, adaptor.getCondition(), op.getMessageAttr(),
|
||||
op.getFileAttr(), op.getFuncAttr(), op.getLineAttr()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
|
||||
public:
|
||||
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
|
||||
op, op.getName(), op.getFunctionType());
|
||||
addNamedAttrs(newOp, adaptor.getAttributes());
|
||||
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
|
||||
newOp.getBody().end());
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TritonCallOpPattern : public OpConversionPattern<triton::CallOp> {
|
||||
public:
|
||||
using OpConversionPattern<triton::CallOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::CallOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::CallOp>(
|
||||
op, op.getCallee(), op.getResultTypes(), adaptor.getOperands());
|
||||
addNamedAttrs(newOp, adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TritonReturnOpPattern : public OpConversionPattern<ReturnOp> {
|
||||
public:
|
||||
using OpConversionPattern<ReturnOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, unsigned numCTAs) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.insert< // TODO: view should have custom pattern that views the
|
||||
// layout
|
||||
TritonGenericPattern<triton::AdvanceOp>,
|
||||
TritonGenericPattern<triton::MakeTensorPtrOp>,
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::BitcastOp>,
|
||||
TritonGenericPattern<triton::FpToFpOp>,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
||||
TritonGenericPattern<triton::ElementwiseInlineAsmOp>, TritonReducePattern,
|
||||
TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern,
|
||||
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
|
||||
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
|
||||
TritonGenericPattern<triton::ExternElementwiseOp>, TritonPrintPattern,
|
||||
TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern,
|
||||
TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
// SCF patterns
|
||||
//
|
||||
// This is borrowed from ConvertForOpTypes in
|
||||
// SCF/Transforms/StructuralTypeConversions.cpp
|
||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||
// Ref: ConvertForOpTypes
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp =
|
||||
cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
||||
newOp.getLoopBody().end());
|
||||
|
||||
// Now, update all the types.
|
||||
|
||||
// Convert the types of block arguments within the given region. This
|
||||
// replaces each block with a new block containing the updated signature.
|
||||
// The entry block may have a special conversion if `entryConversion` is
|
||||
// provided. On success, the new entry block to the region is returned for
|
||||
// convenience. Otherwise, failure is returned.
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
||||
*getTypeConverter()))) {
|
||||
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
||||
}
|
||||
// Change the clone to use the updated operands. We could have cloned with
|
||||
// a IRMapping, but this seems a bit more direct.
|
||||
newOp->setOperands(adaptor.getOperands());
|
||||
// Update the result types to the new converted types.
|
||||
SmallVector<Type> newResultTypes;
|
||||
for (Type type : op.getResultTypes()) {
|
||||
Type newType = typeConverter->convertType(type);
|
||||
if (!newType)
|
||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||
newResultTypes.push_back(newType);
|
||||
}
|
||||
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
|
||||
rewriter.replaceOp(op, newOp.getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
|
||||
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
|
||||
// op.erase();
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// This is borrowed from ConvertFIfOpTypes in
|
||||
// SCF/Transforms/StructuralTypeConversions.cpp
|
||||
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
|
||||
public:
|
||||
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// TODO: Generalize this to any type conversion, not just 1:1.
|
||||
//
|
||||
// We need to implement something more sophisticated here that tracks which
|
||||
// types convert to which other types and does the appropriate
|
||||
// materialization logic.
|
||||
// For example, it's possible that one result type converts to 0 types and
|
||||
// another to 2 types, so newResultTypes would at least be the right size to
|
||||
// not crash in the llvm::zip call below, but then we would set the the
|
||||
// wrong type on the SSA values! These edge cases are also why we cannot
|
||||
// safely use the TypeConverter::convertTypes helper here.
|
||||
SmallVector<Type> newResultTypes;
|
||||
for (auto type : op.getResultTypes()) {
|
||||
Type newType = typeConverter->convertType(type);
|
||||
if (!newType)
|
||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||
newResultTypes.push_back(newType);
|
||||
}
|
||||
|
||||
// See comments in the ForOp pattern for why we clone without regions and
|
||||
// then inline.
|
||||
scf::IfOp newOp =
|
||||
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
|
||||
newOp.getThenRegion().end());
|
||||
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
|
||||
newOp.getElseRegion().end());
|
||||
|
||||
// Update the operands and types.
|
||||
newOp->setOperands(adaptor.getOperands());
|
||||
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
rewriter.replaceOp(op, newOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// This is borrowed from ConvertFIfOpTypes in
|
||||
// SCF/Transforms/StructuralTypeConversions.cpp
|
||||
class SCFWhilePattern : public OpConversionPattern<scf::WhileOp> {
|
||||
public:
|
||||
using OpConversionPattern<scf::WhileOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *converter = getTypeConverter();
|
||||
assert(converter);
|
||||
SmallVector<Type> newResultTypes;
|
||||
if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
|
||||
return failure();
|
||||
|
||||
auto newOp = rewriter.create<scf::WhileOp>(op.getLoc(), newResultTypes,
|
||||
adaptor.getOperands());
|
||||
for (auto i : {0u, 1u}) {
|
||||
auto &dstRegion = newOp.getRegion(i);
|
||||
rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
|
||||
if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
|
||||
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
||||
}
|
||||
rewriter.replaceOp(op, newOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class SCFConditionPattern : public OpConversionPattern<scf::ConditionOp> {
|
||||
public:
|
||||
using OpConversionPattern<scf::ConditionOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern, SCFWhilePattern,
|
||||
SCFConditionPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
// CF
|
||||
|
||||
class CFBranchPattern : public OpConversionPattern<cf::BranchOp> {
|
||||
public:
|
||||
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
|
||||
op, op.getSuccessor(), adaptor.getOperands());
|
||||
if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(),
|
||||
*converter)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class CFCondBranchPattern : public OpConversionPattern<cf::CondBranchOp> {
|
||||
public:
|
||||
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
|
||||
op, adaptor.getCondition(), op.getTrueDest(),
|
||||
adaptor.getTrueDestOperands(), op.getFalseDest(),
|
||||
adaptor.getFalseDestOperands());
|
||||
addNamedAttrs(newOp, adaptor.getAttributes());
|
||||
|
||||
if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(),
|
||||
*converter)))
|
||||
return failure();
|
||||
if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(),
|
||||
*converter)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<CFCondBranchPattern, CFBranchPattern>(typeConverter, context);
|
||||
}
|
||||
//
|
||||
|
||||
class ConvertTritonToTritonGPUROCM
|
||||
: public ConvertTritonToTritonGPUROCMBase<ConvertTritonToTritonGPUROCM> {
|
||||
public:
|
||||
ConvertTritonToTritonGPUROCM() = default;
|
||||
// constructor with some parameters set explicitly.
|
||||
ConvertTritonToTritonGPUROCM(int numWarps, int threadsPerWarp, int numCTAs,
|
||||
int computeCapability) {
|
||||
this->numWarps = numWarps;
|
||||
this->threadsPerWarp = threadsPerWarp;
|
||||
this->numCTAs = numCTAs;
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp,
|
||||
numCTAs);
|
||||
TritonGPUConversionTarget target(*context, typeConverter);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateStdPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateArithPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateMathPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns, numCTAs);
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
populateCFPatterns(typeConverter, patterns);
|
||||
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
auto inti = llvm::APSInt(32, false);
|
||||
auto i32_ty = IntegerType::get(mod->getContext(), 32);
|
||||
|
||||
mod->setAttr(
|
||||
AttrNumWarpsName,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
|
||||
mod->setAttr(
|
||||
AttrNumThreadsPerWarp,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue())));
|
||||
|
||||
mod->setAttr(AttrNumCTAsName,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue())));
|
||||
|
||||
mod->setAttr(AttrComputeCapabilityName,
|
||||
IntegerAttr::get(
|
||||
i32_ty, llvm::APInt(32, computeCapability.getValue())));
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
// if (failed(target.refineLayouts(mod, numWarps)))
|
||||
// return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUROCMPass(int numWarps,
|
||||
int threadsPerWarp,
|
||||
int numCTAs,
|
||||
int computeCapability) {
|
||||
return std::make_unique<::ConvertTritonToTritonGPUROCM>(
|
||||
numWarps, threadsPerWarp, numCTAs, computeCapability);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUROCMPass() {
|
||||
return std::make_unique<::ConvertTritonToTritonGPUROCM>();
|
||||
}
|
||||
1
python/triton/third_party/hip/lib/Dialect/CMakeLists.txt
vendored
Normal file
1
python/triton/third_party/hip/lib/Dialect/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(TritonGPUROCM)
|
||||
2
python/triton/third_party/hip/lib/Dialect/TritonGPUROCM/CMakeLists.txt
vendored
Normal file
2
python/triton/third_party/hip/lib/Dialect/TritonGPUROCM/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
13
python/triton/third_party/hip/lib/Dialect/TritonGPUROCM/IR/CMakeLists.txt
vendored
Normal file
13
python/triton/third_party/hip/lib/Dialect/TritonGPUROCM/IR/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
add_mlir_dialect_library(TritonGPUROCMIR
|
||||
Dialect.cpp
|
||||
Traits.cpp
|
||||
Types.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonGPUROCMTableGen
|
||||
TritonGPUROCMAttrDefsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRGPUOps
|
||||
TritonIR
|
||||
)
|
||||
1963
python/triton/third_party/hip/lib/Dialect/TritonGPUROCM/IR/Dialect.cpp
vendored
Normal file
1963
python/triton/third_party/hip/lib/Dialect/TritonGPUROCM/IR/Dialect.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user