mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Triton-MLIR][BACKEND] Refactor TritonGPUToLLVM into several files (#988)
Refactor the backend into multiple smaller files.
This commit is contained in:
@@ -10,8 +10,8 @@
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
|
||||
@@ -10,20 +10,21 @@ namespace triton {
|
||||
namespace type {
|
||||
|
||||
// Integer types
|
||||
Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
|
||||
Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
|
||||
Type u32Ty(MLIRContext *ctx) {
|
||||
// TODO(Superjomn): may change `static` into better implementations
|
||||
static Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
|
||||
static Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
|
||||
static Type u32Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
|
||||
}
|
||||
Type u1Ty(MLIRContext *ctx) {
|
||||
static Type u1Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
|
||||
}
|
||||
|
||||
// Float types
|
||||
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
|
||||
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
|
||||
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
|
||||
Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
|
||||
static Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
|
||||
static Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
|
||||
static Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
|
||||
static Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
|
||||
|
||||
static bool isFloat(Type type) {
|
||||
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
@@ -172,11 +172,11 @@ private:
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
// Make the oprands in argArchive follow the provided \param order.
|
||||
// Make the operands in argArchive follow the provided \param order.
|
||||
void reorderArgArchive(ArrayRef<Operand *> order) {
|
||||
assert(order.size() == argArchive.size());
|
||||
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
|
||||
// it do necessary when onlyAttachMLIRArgs is true for the $0,$1.. are
|
||||
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
|
||||
// determined by PTX code snippet passed from external.
|
||||
sort(argArchive.begin(), argArchive.end(),
|
||||
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
|
||||
@@ -306,8 +306,7 @@ struct PTXInstrExecution {
|
||||
bool onlyAttachMLIRArgs{};
|
||||
};
|
||||
|
||||
//// =============================== Some instruction wrappers
|
||||
///===============================
|
||||
/// ====== Some instruction wrappers ======
|
||||
// We add the wrappers to make the usage more intuitive by avoiding mixing the
|
||||
// PTX code with some trivial C++ code.
|
||||
|
||||
@@ -324,4 +323,4 @@ struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
#endif
|
||||
@@ -1,48 +0,0 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
|
||||
mlir::LLVMTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
namespace LLVM {
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
// Names for identifying different NVVM annotations. It is used as attribute
|
||||
// names in MLIR modules. Refer to
|
||||
// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for
|
||||
// the full list.
|
||||
struct NVVMMetadataField {
|
||||
static constexpr char MaxNTid[] = "nvvm.maxntid";
|
||||
static constexpr char Kernel[] = "nvvm.kernel";
|
||||
};
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,22 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
|
||||
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#ifndef TRITON_CONVERSION_PASSDETAIL_H
|
||||
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,6 +1,6 @@
|
||||
add_mlir_conversion_library(TritonGPUToLLVM
|
||||
TritonGPUToLLVM.cpp
|
||||
PtxAsmFormat.cpp
|
||||
TritonGPUToLLVMPass.cpp
|
||||
PTXAsmFormat.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_HELPERS_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_HELPERS_H
|
||||
|
||||
#include "./Utility.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
@@ -22,14 +23,11 @@
|
||||
#include "triton/Analysis/Membar.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
#include "Utility.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <sstream> // unify to llvm::raw_string_ostream ?
|
||||
// TODO(Superjomn): unify to llvm::raw_string_ostream
|
||||
#include <sstream>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
File diff suppressed because it is too large
Load Diff
541
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Normal file
541
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Normal file
@@ -0,0 +1,541 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
|
||||
|
||||
#include "Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::SharedMemoryObject;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
|
||||
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
||||
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
|
||||
// since it is not exposed on header files in mlir v14
|
||||
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
|
||||
// All the rights are reserved by the LLVM community.
|
||||
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
||||
private:
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||
bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
for (const auto &attr : attrs) {
|
||||
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
||||
attr.getName() == "std.varargs" ||
|
||||
(filterArgAttrs &&
|
||||
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
|
||||
continue;
|
||||
result.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function for wrapping all attributes into a single DictionaryAttr
|
||||
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
|
||||
return DictionaryAttr::get(
|
||||
b.getContext(), b.getNamedAttr("llvm.struct_attrs", attrs));
|
||||
}
|
||||
|
||||
protected:
|
||||
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
|
||||
// to this legalization pattern.
|
||||
LLVM::LLVMFuncOp
|
||||
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Convert the original function arguments. They are converted using the
|
||||
// LLVMTypeConverter provided to this legalization pattern.
|
||||
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
|
||||
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
||||
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
||||
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
|
||||
if (!llvmType)
|
||||
return nullptr;
|
||||
|
||||
// Propagate argument/result attributes to all converted arguments/result
|
||||
// obtained after converting a given original argument/result.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
|
||||
attributes);
|
||||
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
|
||||
assert(!resAttrDicts.empty() && "expected array to be non-empty");
|
||||
auto newResAttrDicts =
|
||||
(funcOp.getNumResults() == 1)
|
||||
? resAttrDicts
|
||||
: rewriter.getArrayAttr(
|
||||
{wrapAsStructAttrs(rewriter, resAttrDicts)});
|
||||
attributes.push_back(rewriter.getNamedAttr(
|
||||
FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
|
||||
}
|
||||
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
|
||||
SmallVector<Attribute, 4> newArgAttrs(
|
||||
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
|
||||
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
|
||||
auto mapping = result.getInputMapping(i);
|
||||
assert(mapping && "unexpected deletion of function argument");
|
||||
for (size_t j = 0; j < mapping->size; ++j)
|
||||
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
|
||||
}
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
|
||||
rewriter.getArrayAttr(newArgAttrs)));
|
||||
}
|
||||
for (const auto &pair : llvm::enumerate(attributes)) {
|
||||
if (pair.value().getName() == "llvm.linkage") {
|
||||
attributes.erase(attributes.begin() + pair.index());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Create an LLVM function, use external linkage by default until MLIR
|
||||
// functions have linkage.
|
||||
LLVM::Linkage linkage = LLVM::Linkage::External;
|
||||
if (funcOp->hasAttr("llvm.linkage")) {
|
||||
auto attr =
|
||||
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
|
||||
if (!attr) {
|
||||
funcOp->emitError()
|
||||
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
|
||||
return nullptr;
|
||||
}
|
||||
linkage = attr.getLinkage();
|
||||
}
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
/*dsoLocal*/ false, attributes);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
|
||||
&result)))
|
||||
return nullptr;
|
||||
|
||||
return newFuncOp;
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||
static Value
|
||||
getStructFromSharedMemoryObject(Location loc,
|
||||
const SharedMemoryObject &smemObj,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto elems = smemObj.getElems();
|
||||
auto types = smemObj.getTypes();
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp>
|
||||
class ConvertTritonGPUOpToLLVMPattern
|
||||
: public ConvertOpToLLVMPattern<SourceOp>,
|
||||
public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
allocation(allocation), smem(smem) {}
|
||||
|
||||
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
auto cast = rewriter.create<UnrealizedConversionCastOp>(
|
||||
loc, TypeRange{llvmIndexTy},
|
||||
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
|
||||
Value threadId = cast.getResult(0);
|
||||
return threadId;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Utilities
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) const {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
multiDim[order[i]] = reorderedMultiDim[i];
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape) const {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank > 0);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
if (rank == 1) {
|
||||
multiDim[0] = linear;
|
||||
} else {
|
||||
Value remained = linear;
|
||||
for (auto &&en : llvm::enumerate(shape.drop_back())) {
|
||||
Value dimSize = idx_val(en.value());
|
||||
multiDim[en.index()] = urem(remained, dimSize);
|
||||
remained = udiv(remained, dimSize);
|
||||
}
|
||||
multiDim[rank - 1] = remained;
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) const {
|
||||
return linearize(rewriter, loc, reorder<Value>(multiDim, order),
|
||||
reorder<unsigned>(shape, order));
|
||||
}
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
|
||||
auto rank = multiDim.size();
|
||||
Value linear = idx_val(0);
|
||||
if (rank > 0) {
|
||||
linear = multiDim.back();
|
||||
for (auto [dim, dimShape] :
|
||||
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
|
||||
Value dimSize = idx_val(dimShape);
|
||||
linear = add(mul(linear, dimSize), dim);
|
||||
}
|
||||
}
|
||||
return linear;
|
||||
}
|
||||
|
||||
Value dot(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
|
||||
assert(offsets.size() == strides.size());
|
||||
Value ret = idx_val(0);
|
||||
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
|
||||
ret = add(ret, mul(offset, stride));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Blocked layout indices
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Get an index-base for each dimension for a \param blocked_layout.
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForBlockedLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = idx_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
auto sizePerThread = blocked_layout.getSizePerThread();
|
||||
auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
|
||||
auto order = blocked_layout.getOrder();
|
||||
unsigned rank = shape.size();
|
||||
|
||||
// delinearize threadId to get the base index
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
SmallVector<Value> multiDimThreadId =
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
|
||||
SmallVector<Value> multiDimBase(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// Wrap around multiDimWarpId/multiDimThreadId incase
|
||||
// shape[k] > shapePerCTA[k]
|
||||
auto maxWarps =
|
||||
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
|
||||
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
|
||||
multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
|
||||
multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
|
||||
// multiDimBase[k] = (multiDimThreadId[k] +
|
||||
// multiDimWarpId[k] * threadsPerWarp[k]) *
|
||||
// sizePerThread[k];
|
||||
Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
|
||||
Value sizePerThreadK = idx_val(sizePerThread[k]);
|
||||
multiDimBase[k] =
|
||||
mul(sizePerThreadK, add(multiDimThreadId[k],
|
||||
mul(multiDimWarpId[k], threadsPerWarpK)));
|
||||
}
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto sizePerThread = blockedLayout.getSizePerThread();
|
||||
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
|
||||
auto order = blockedLayout.getOrder();
|
||||
|
||||
unsigned rank = shape.size();
|
||||
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
|
||||
SmallVector<unsigned> tilesPerDim(rank);
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
|
||||
|
||||
SmallVector<SmallVector<unsigned>> offset(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
|
||||
for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
|
||||
++blockOffset)
|
||||
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
|
||||
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
|
||||
++threadOffset)
|
||||
for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
|
||||
++elemOffset)
|
||||
offset[k].push_back(blockOffset * sizePerThread[k] *
|
||||
threadsPerWarp[k] * warpsPerCTA[k] +
|
||||
warpOffset * sizePerThread[k] *
|
||||
threadsPerWarp[k] +
|
||||
threadOffset * sizePerThread[k] + elemOffset);
|
||||
}
|
||||
|
||||
unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
|
||||
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
|
||||
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
|
||||
for (unsigned n = 0; n < elemsPerThread; ++n) {
|
||||
unsigned linearNanoTileId = n / totalSizePerThread;
|
||||
unsigned linearNanoTileElemId = n % totalSizePerThread;
|
||||
SmallVector<unsigned> multiDimNanoTileId =
|
||||
getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
|
||||
SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
|
||||
linearNanoTileElemId, sizePerThread, order);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
unsigned reorderedMultiDimId =
|
||||
multiDimNanoTileId[k] *
|
||||
(sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
|
||||
multiDimNanoTileElemId[k];
|
||||
reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]);
|
||||
}
|
||||
}
|
||||
return reorderedOffset;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Mma layout indices
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
|
||||
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
|
||||
ret.push_back({i, j});
|
||||
ret.push_back({i, j + 1});
|
||||
ret.push_back({i + 2, j});
|
||||
ret.push_back({i + 2, j + 1});
|
||||
ret.push_back({i, j + 8});
|
||||
ret.push_back({i, j + 9});
|
||||
ret.push_back({i + 2, j + 8});
|
||||
ret.push_back({i + 2, j + 9});
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
assert(_warpsPerCTA.size() == 2);
|
||||
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
|
||||
idx_val(_warpsPerCTA[1])};
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = idx_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value warpId0 = urem(warpId, warpsPerCTA[0]);
|
||||
Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
|
||||
Value offWarp0 = mul(warpId0, idx_val(16));
|
||||
Value offWarp1 = mul(warpId1, idx_val(8));
|
||||
|
||||
SmallVector<Value> multiDimBase(2);
|
||||
multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
|
||||
multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
|
||||
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
|
||||
ret.push_back({i, j});
|
||||
ret.push_back({i, j + 1});
|
||||
ret.push_back({i + 8, j});
|
||||
ret.push_back({i + 8, j + 1});
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Get offsets / indices for any layout
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
SmallVector<Value> emitBaseIndexForLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Attribute &layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
|
||||
return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
|
||||
if (mmaLayout.isAmpere())
|
||||
return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
|
||||
}
|
||||
llvm_unreachable("unsupported emitBaseIndexForLayout");
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
|
||||
return emitOffsetForBlockedLayout(blockedLayout, shape);
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
|
||||
if (mmaLayout.isAmpere())
|
||||
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
|
||||
}
|
||||
llvm_unreachable("unsupported emitOffsetForLayout");
|
||||
}
|
||||
|
||||
// Emit indices calculation within each ConversionPattern, and returns a
|
||||
// [elemsPerThread X rank] index matrix.
|
||||
|
||||
// TODO: [phil] redundant indices computation do not appear to hurt
|
||||
// performance much, but they could still significantly slow down
|
||||
// computations.
|
||||
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Attribute &layout, ArrayRef<int64_t> shape) const {
|
||||
|
||||
// step 1, delinearize threadId to get the base index
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
|
||||
// step 2, get offset of each element
|
||||
auto offset = emitOffsetForLayout(layout, shape);
|
||||
// step 3, add offset to base, and reorder the sequence of indices to
|
||||
// guarantee that elems in the same sizePerThread are adjacent in order
|
||||
unsigned rank = shape.size();
|
||||
unsigned elemsPerThread = offset.size();
|
||||
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
|
||||
SmallVector<Value>(rank));
|
||||
for (unsigned n = 0; n < elemsPerThread; ++n)
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
|
||||
|
||||
return multiDimIdx;
|
||||
}
|
||||
|
||||
struct SmallVectorKeyInfo {
|
||||
static unsigned getHashValue(const SmallVector<unsigned> &key) {
|
||||
return llvm::hash_combine_range(key.begin(), key.end());
|
||||
}
|
||||
static bool isEqual(const SmallVector<unsigned> &lhs,
|
||||
const SmallVector<unsigned> &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
static SmallVector<unsigned> getEmptyKey() {
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
static SmallVector<unsigned> getTombstoneKey() {
|
||||
return {std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
};
|
||||
|
||||
SmallVector<SmallVector<Value>>
|
||||
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SliceEncodingAttr &sliceLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto parent = sliceLayout.getParent();
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
size_t rank = shape.size();
|
||||
auto parentIndices =
|
||||
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
|
||||
unsigned numIndices = parentIndices.size();
|
||||
SmallVector<SmallVector<Value>> resultIndices;
|
||||
for (unsigned i = 0; i < numIndices; ++i) {
|
||||
SmallVector<Value> indices = parentIndices[i];
|
||||
indices.erase(indices.begin() + dim);
|
||||
resultIndices.push_back(indices);
|
||||
}
|
||||
return resultIndices;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Emit indices
|
||||
// -----------------------------------------------------------------------
|
||||
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
||||
ConversionPatternRewriter &b,
|
||||
const Attribute &layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return emitIndicesForDistributedLayout(loc, b, blocked, shape);
|
||||
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
return emitIndicesForDistributedLayout(loc, b, mma, shape);
|
||||
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
return emitIndicesForSliceLayout(loc, b, slice, shape);
|
||||
} else {
|
||||
assert(0 && "emitIndices for layouts other than blocked & slice not "
|
||||
"implemented yet");
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Shared memory utilities
|
||||
// -----------------------------------------------------------------------
|
||||
template <typename T>
|
||||
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
|
||||
T value) const {
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||
size_t offset = allocation->getOffset(bufferId);
|
||||
Value offVal = idx_val(offset);
|
||||
Value base = gep(ptrTy, smem, offVal);
|
||||
return base;
|
||||
}
|
||||
|
||||
protected:
|
||||
const Allocation *allocation;
|
||||
Value smem;
|
||||
};
|
||||
|
||||
#endif
|
||||
339
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
Normal file
339
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
Normal file
@@ -0,0 +1,339 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Membar.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
#include "TritonGPUToLLVM.h"
|
||||
#include "TypeConverter.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
addIllegalDialect<triton::TritonDialect>();
|
||||
addIllegalDialect<triton::gpu::TritonGPUDialect>();
|
||||
addIllegalDialect<mlir::gpu::GPUDialect>();
|
||||
addIllegalDialect<mlir::StandardOpsDialect>();
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
};
|
||||
|
||||
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
addIllegalOp<mlir::FuncOp>();
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
|
||||
class ConvertTritonGPUToLLVM
|
||||
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||
|
||||
public:
|
||||
explicit ConvertTritonGPUToLLVM(int computeCapability)
|
||||
: computeCapability(computeCapability) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
option.overrideIndexBitwidth(32);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context);
|
||||
TritonLLVMConversionTarget target(*context);
|
||||
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
|
||||
// Step 1: Decompose unoptimized layout conversions to use shared memory
|
||||
// Step 2: Decompose insert_slice_async to use load + insert_slice for
|
||||
// pre-Ampere architectures or unsupported vectorized load sizes
|
||||
// Step 3: Allocate shared memories and insert barriers
|
||||
// Step 4: Convert SCF to CFG
|
||||
// Step 5: Convert FuncOp to LLVMFuncOp via partial conversion
|
||||
// Step 6: Convert the rest of ops via partial conversion
|
||||
//
|
||||
// The reason for putting step 1 before step 2 is that the membar
|
||||
// analysis currently only supports SCF but not CFG. The reason for a
|
||||
// separation between 1/4 is that, step 3 is out of the scope of Dialect
|
||||
// Conversion, thus we need to make sure the smem is not revised during the
|
||||
// conversion of step 4.
|
||||
|
||||
// Step 1
|
||||
decomposeMmaToDotOperand(mod, numWarps);
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
|
||||
// Step 2
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
|
||||
// Step 3
|
||||
Allocation allocation(mod);
|
||||
MembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
// Step 4
|
||||
RewritePatternSet scf_patterns(context);
|
||||
mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
||||
mlir::ConversionTarget scf_target(*context);
|
||||
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
|
||||
scf::WhileOp, scf::ExecuteRegionOp>();
|
||||
scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
if (failed(
|
||||
applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Step 5
|
||||
RewritePatternSet func_patterns(context);
|
||||
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
|
||||
if (failed(
|
||||
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Step 6 - get axis and shared memory info
|
||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||
axisInfoAnalysis.run(mod);
|
||||
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
|
||||
mod->setAttr("triton_gpu.shared",
|
||||
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
|
||||
allocation.getSharedMemorySize()));
|
||||
|
||||
// Step 6 - rewrite rest of ops
|
||||
// We set a higher benefit here to ensure triton's patterns runs before
|
||||
// arith patterns for some encoding not supported by the community
|
||||
// patterns.
|
||||
RewritePatternSet patterns(context);
|
||||
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
// Add arith/math's patterns to help convert scalar expression to LLVM.
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
Value smem;
|
||||
|
||||
int computeCapability{};
|
||||
|
||||
void initSharedMemory(size_t size,
|
||||
TritonGPUToLLVMTypeConverter &typeConverter) {
|
||||
ModuleOp mod = getOperation();
|
||||
OpBuilder b(mod.getBodyRegion());
|
||||
auto loc = mod.getLoc();
|
||||
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
|
||||
// Set array size 0 and external linkage indicates that we use dynamic
|
||||
// shared allocation to allow a larger shared memory size for each kernel.
|
||||
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
|
||||
auto global = b.create<LLVM::GlobalOp>(
|
||||
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
|
||||
"global_smem", /*value=*/Attribute(), /*alignment=*/0,
|
||||
mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
|
||||
SmallVector<LLVM::LLVMFuncOp> funcs;
|
||||
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
|
||||
assert(funcs.size() == 1 &&
|
||||
"Inliner pass is expected before TritonGPUToLLVM");
|
||||
b.setInsertionPointToStart(&funcs[0].getBody().front());
|
||||
smem = b.create<LLVM::AddressOfOp>(loc, global);
|
||||
auto ptrTy =
|
||||
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
|
||||
smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
|
||||
}
|
||||
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const {
|
||||
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcMma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (srcMma && dstDotOp &&
|
||||
!ConvertLayoutOpConversion::isMmaToDotShortcut(srcMma, dstDotOp)) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
|
||||
getOrder(srcMma), numWarps));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeBlockedToDotOperand(ModuleOp mod) const {
|
||||
// Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
|
||||
// because the codegen doesn't handle `blocked -> dot_op` directly
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (srcBlocked && dstDotOp) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
mod.getContext(), dstDotOp, srcType.getShape(),
|
||||
getOrder(srcBlocked), srcType.getElementType()));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||
axisInfoAnalysis.run(mod);
|
||||
// TODO(Keren): This is a hacky knob that may cause performance regression
|
||||
// when decomposition has been performed. We should remove this knob once we
|
||||
// have thorough analysis on async wait. Currently, we decompose
|
||||
// `insert_slice_async` into `load` and `insert_slice` without knowing which
|
||||
// `async_wait` is responsible for the `insert_slice_async`. To guarantee
|
||||
// correctness, we blindly set the `async_wait` to wait for all async ops.
|
||||
//
|
||||
// There are two options to improve this:
|
||||
// 1. We can perform a dataflow analysis to find the `async_wait` that is
|
||||
// responsible for the `insert_slice_async` in the backend.
|
||||
// 2. We can modify the pipeline to perform the decomposition before the
|
||||
// `async_wait` is inserted. However, it is also risky because we don't know
|
||||
// the correct vectorized shape yet in the pipeline pass. Making the
|
||||
// pipeline pass aware of the vectorization could introduce additional
|
||||
// dependencies on the AxisInfoAnalysis and the Coalesce analysis.
|
||||
bool decomposed = false;
|
||||
// insert_slice_async %src, %dst, %idx, %mask, %other
|
||||
// =>
|
||||
// %tmp = load %src, %mask, %other
|
||||
// %res = insert_slice %tmp into %dst[%idx]
|
||||
mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
|
||||
OpBuilder builder(insertSliceAsyncOp);
|
||||
|
||||
// Get the vectorized load size
|
||||
auto src = insertSliceAsyncOp.src();
|
||||
auto dst = insertSliceAsyncOp.dst();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto resSharedLayout =
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
|
||||
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
|
||||
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
|
||||
auto byteWidth = bitWidth / 8;
|
||||
|
||||
// If the load byte width is not eligible or the current compute
|
||||
// capability does not support async copy, then we do decompose
|
||||
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
|
||||
computeCapability)
|
||||
.contains(byteWidth))
|
||||
return;
|
||||
|
||||
// load
|
||||
auto tmpTy =
|
||||
RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
|
||||
auto loadOp = builder.create<triton::LoadOp>(
|
||||
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
|
||||
insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
|
||||
insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
|
||||
insertSliceAsyncOp.isVolatile());
|
||||
|
||||
// insert_slice
|
||||
auto axis = insertSliceAsyncOp.axis();
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
|
||||
auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
|
||||
auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
|
||||
offsets[axis] = insertSliceAsyncOp.index();
|
||||
for (size_t i = 0; i < dstTy.getRank(); i++) {
|
||||
if (i != axis)
|
||||
sizes[i] = intAttr(dstTy.getShape()[i]);
|
||||
}
|
||||
auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
|
||||
insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.dst(),
|
||||
offsets, sizes, strides);
|
||||
|
||||
// Replace
|
||||
insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
|
||||
insertSliceAsyncOp.erase();
|
||||
decomposed = true;
|
||||
});
|
||||
|
||||
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
|
||||
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
|
||||
// async wait is supported in Ampere and later
|
||||
asyncWaitOp.erase();
|
||||
} else if (decomposed) {
|
||||
// Wait for all previous async ops
|
||||
OpBuilder builder(asyncWaitOp);
|
||||
auto newAsyncWaitOp =
|
||||
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
||||
asyncWaitOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability) {
|
||||
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability);
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
146
lib/Conversion/TritonGPUToLLVM/TypeConverter.h
Normal file
146
lib/Conversion/TritonGPUToLLVM/TypeConverter.h
Normal file
@@ -0,0 +1,146 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
|
||||
#include "DotHelpers.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::DotOpFMAConversionHelper;
|
||||
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
|
||||
using ::mlir::LLVM::MMA16816ConversionHelper;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
|
||||
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
|
||||
public:
|
||||
using TypeConverter::convertType;
|
||||
|
||||
TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
|
||||
const DataLayoutAnalysis *analysis = nullptr)
|
||||
: LLVMTypeConverter(ctx, option, analysis) {
|
||||
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
|
||||
return convertTritonPointerType(type);
|
||||
});
|
||||
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
|
||||
return convertTritonTensorType(type);
|
||||
});
|
||||
// Internally store float8 as int8
|
||||
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
|
||||
return IntegerType::get(type.getContext(), 8);
|
||||
});
|
||||
}
|
||||
|
||||
Type convertTritonPointerType(triton::PointerType type) {
|
||||
// Recursively translate pointee type
|
||||
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
|
||||
type.getAddressSpace());
|
||||
}
|
||||
|
||||
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
|
||||
auto ctx = type.getContext();
|
||||
Attribute layout = type.getEncoding();
|
||||
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
|
||||
|
||||
if (layout &&
|
||||
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
|
||||
layout.isa<MmaEncodingAttr>())) {
|
||||
unsigned numElementsPerThread = getElemsPerThread(type);
|
||||
SmallVector<Type, 4> types(numElementsPerThread,
|
||||
convertType(type.getElementType()));
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
} else if (auto shared_layout =
|
||||
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
|
||||
SmallVector<Type, 4> types;
|
||||
// base ptr
|
||||
auto ptrType =
|
||||
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
|
||||
types.push_back(ptrType);
|
||||
// shape dims
|
||||
auto rank = type.getRank();
|
||||
// offsets + strides
|
||||
for (auto i = 0; i < rank * 2; i++) {
|
||||
types.push_back(IntegerType::get(ctx, 32));
|
||||
}
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
} else if (auto dotOpLayout =
|
||||
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
|
||||
if (dotOpLayout.getParent()
|
||||
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
|
||||
int numElemsPerThread =
|
||||
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
|
||||
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
|
||||
} else { // for parent is MMA layout
|
||||
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
Type elemTy = convertType(type.getElementType());
|
||||
if (mmaLayout.isAmpere()) {
|
||||
const llvm::DenseMap<int, Type> targetTyMap = {
|
||||
{32, elemTy},
|
||||
{16, vec_ty(elemTy, 2)},
|
||||
{8, vec_ty(elemTy, 4)},
|
||||
};
|
||||
Type targetTy;
|
||||
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
|
||||
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
|
||||
} else {
|
||||
assert(false && "Unsupported element type");
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
auto elems =
|
||||
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(elems, targetTy));
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 1) { // $b
|
||||
auto elems =
|
||||
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
|
||||
return struct_ty(SmallVector<Type>(elems, targetTy));
|
||||
}
|
||||
}
|
||||
|
||||
if (mmaLayout.isVolta()) {
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
// TODO[Superjomn]: Both transA and transB are not available here.
|
||||
bool trans = false;
|
||||
// TODO[Superjomn]: The order of A and B are not available here.
|
||||
SmallVector<unsigned> order({1, 0});
|
||||
if (trans) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
std::swap(order[0], order[1]);
|
||||
}
|
||||
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
int elems = helper.numElemsPerThreadA(shape, order);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 1) { // $b
|
||||
int elems = helper.numElemsPerThreadB(shape, order);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::errs() << "Unexpected dot operand layout detected in "
|
||||
"TritonToLLVMTypeConverter";
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
return llvm::None;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,34 +1,11 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
|
||||
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Membar.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
// Operators
|
||||
@@ -115,13 +92,76 @@
|
||||
#define idx_val(...) \
|
||||
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define tid_val() getThreadId(rewriter, loc)
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Delinearize supposing order is [0, 1, .. , n]
|
||||
template <typename T>
|
||||
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
|
||||
llvm::ArrayRef<T> shape) {
|
||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||
size_t rank = shape.size();
|
||||
T accMul = product(shape.drop_back());
|
||||
T linearRemain = linearIndex;
|
||||
llvm::SmallVector<T> multiDimIndex(rank);
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
multiDimIndex[i] = linearRemain / accMul;
|
||||
linearRemain = linearRemain % accMul;
|
||||
if (i != 0) {
|
||||
accMul = accMul / shape[i - 1];
|
||||
}
|
||||
}
|
||||
return multiDimIndex;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::SmallVector<T> getMultiDimIndex(T linearIndex, llvm::ArrayRef<T> shape,
|
||||
llvm::ArrayRef<unsigned> order) {
|
||||
size_t rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
|
||||
llvm::SmallVector<T> multiDim(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
multiDim[order[i]] = reorderedMultiDim[i];
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
// Linearize supposing order is [0, 1, .. , n]
|
||||
template <typename T>
|
||||
static T getLinearIndexImpl(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape) {
|
||||
assert(multiDimIndex.size() == shape.size());
|
||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||
size_t rank = shape.size();
|
||||
T accMul = product(shape.drop_back());
|
||||
T linearIndex = 0;
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
linearIndex += multiDimIndex[i] * accMul;
|
||||
if (i != 0) {
|
||||
accMul = accMul / shape[i - 1];
|
||||
}
|
||||
}
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T getLinearIndex(llvm::ArrayRef<T> multiDimIndex,
|
||||
llvm::ArrayRef<T> shape,
|
||||
llvm::ArrayRef<unsigned> order) {
|
||||
assert(shape.size() == order.size());
|
||||
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
|
||||
reorder(shape, order));
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
|
||||
namespace LLVM {
|
||||
using namespace mlir::triton;
|
||||
|
||||
static
|
||||
Value getStructFromElements(Location loc, ValueRange resultVals,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type structType) {
|
||||
@@ -138,6 +178,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
|
||||
return llvmStruct;
|
||||
}
|
||||
|
||||
static
|
||||
SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
|
||||
@@ -155,47 +196,50 @@ SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
|
||||
}
|
||||
|
||||
// Create a 32-bit integer constant.
|
||||
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
|
||||
static Value createConstantI32(Location loc,
|
||||
PatternRewriter &rewriter, int32_t v) {
|
||||
auto i32ty = rewriter.getIntegerType(32);
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
||||
IntegerAttr::get(i32ty, v));
|
||||
}
|
||||
|
||||
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
|
||||
static Value createConstantF32(Location loc,
|
||||
PatternRewriter &rewriter, float v) {
|
||||
auto type = type::f32Ty(rewriter.getContext());
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, type,
|
||||
rewriter.getF32FloatAttr(v));
|
||||
}
|
||||
|
||||
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
|
||||
static Value createConstantF64(Location loc,
|
||||
PatternRewriter &rewriter, float v) {
|
||||
auto type = type::f64Ty(rewriter.getContext());
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, type,
|
||||
rewriter.getF64FloatAttr(v));
|
||||
}
|
||||
|
||||
// Create an index type constant.
|
||||
Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value) {
|
||||
static Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value) {
|
||||
Type ty = converter->convertType(builder.getIndexType());
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
}
|
||||
|
||||
// Create an integer constant of \param width bits.
|
||||
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
int64_t value) {
|
||||
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
|
||||
short width, int64_t value) {
|
||||
Type ty = builder.getIntegerType(width);
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
}
|
||||
|
||||
/// Helper function to get strides from a given shape and its order
|
||||
SmallVector<Value>
|
||||
static SmallVector<Value>
|
||||
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
|
||||
Location loc, ConversionPatternRewriter &rewriter) {
|
||||
auto rank = shape.size();
|
||||
SmallVector<Value> strides(rank);
|
||||
auto stride = 1;
|
||||
int64_t stride = 1;
|
||||
for (auto idx : order) {
|
||||
strides[idx] = i32_val(stride);
|
||||
stride *= shape[idx];
|
||||
@@ -206,7 +250,7 @@ getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
|
||||
struct SharedMemoryObject {
|
||||
Value base; // i32 ptr. The start address of the shared memory object.
|
||||
// We need to store strides as Values but not integers because the
|
||||
// extract_slice instruction can take a slice at artibary offsets.
|
||||
// extract_slice instruction can take a slice at arbitrary offsets.
|
||||
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
|
||||
// 32, we need to let the instruction that uses $a to be aware of that.
|
||||
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
|
||||
@@ -266,7 +310,7 @@ struct SharedMemoryObject {
|
||||
}
|
||||
};
|
||||
|
||||
SharedMemoryObject
|
||||
static SharedMemoryObject
|
||||
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
||||
@@ -276,8 +320,8 @@ getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
|
||||
}
|
||||
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred) {
|
||||
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value ptr, Value val, Value pred) {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||
@@ -290,8 +334,8 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||
}
|
||||
|
||||
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
int i) {
|
||||
static Value shflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value val, int i) {
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
|
||||
if (bits == 64) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
add_mlir_conversion_library(TritonToTritonGPU
|
||||
TritonToTritonGPU.cpp
|
||||
TritonToTritonGPUPass.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
@@ -11,8 +12,8 @@
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Linker/Linker.h"
|
||||
@@ -65,14 +66,14 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
||||
bool hasMetadata{};
|
||||
|
||||
// maxntid
|
||||
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
|
||||
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
|
||||
if (op->hasAttr("nvvm.maxntid")) {
|
||||
auto attr = op->getAttr("nvvm.maxntid");
|
||||
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
|
||||
hasMetadata = true;
|
||||
}
|
||||
|
||||
// kernel
|
||||
if (op->hasAttr(NVVMMetadataField::Kernel)) {
|
||||
if (op->hasAttr("nvvm.kernel")) {
|
||||
meta.is_kernel = true;
|
||||
hasMetadata = true;
|
||||
}
|
||||
@@ -208,7 +209,6 @@ void addExternalLibs(mlir::ModuleOp &module,
|
||||
|
||||
DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs);
|
||||
module.getOperation()->setAttr("triton_gpu.externs", dict);
|
||||
return;
|
||||
}
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
|
||||
|
||||
@@ -13,15 +13,15 @@
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
add_triton_ut(
|
||||
NAME TestPtxAsmFormat
|
||||
SRCS PtxAsmFormatTest.cpp
|
||||
SRCS PTXAsmFormatTest.cpp
|
||||
LIBS TritonGPUToLLVM
|
||||
)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
class PtxAsmFormatTest : public ::testing::Test {
|
||||
class PTXAsmFormatTest : public ::testing::Test {
|
||||
protected:
|
||||
static constexpr int numValues = 4;
|
||||
|
||||
PtxAsmFormatTest() {
|
||||
PTXAsmFormatTest() {
|
||||
ctx.loadDialect<arith::ArithmeticDialect>();
|
||||
|
||||
createValues();
|
||||
@@ -34,7 +35,7 @@ protected:
|
||||
Value v[numValues + 1];
|
||||
};
|
||||
|
||||
TEST_F(PtxAsmFormatTest, basic) {
|
||||
TEST_F(PTXAsmFormatTest, basic) {
|
||||
PTXBuilder builder;
|
||||
|
||||
// Create the operands needed by the instructions in the PTX code.
|
||||
@@ -55,7 +56,7 @@ TEST_F(PtxAsmFormatTest, basic) {
|
||||
ASSERT_EQ(constraints, "=r,b"); // $0 -> =r, $1 -> b
|
||||
}
|
||||
|
||||
TEST_F(PtxAsmFormatTest, complexInstruction) {
|
||||
TEST_F(PTXAsmFormatTest, complexInstruction) {
|
||||
using triton::CacheModifier;
|
||||
using triton::EvictionPolicy;
|
||||
|
||||
@@ -99,7 +100,7 @@ TEST_F(PtxAsmFormatTest, complexInstruction) {
|
||||
EXPECT_EQ(builder.getConstraints(), "l,b");
|
||||
}
|
||||
|
||||
TEST_F(PtxAsmFormatTest, MultiLinePTX) {
|
||||
TEST_F(PTXAsmFormatTest, MultiLinePTX) {
|
||||
PTXBuilder builder;
|
||||
|
||||
auto *constVal = builder.newConstantOperand(1);
|
||||
@@ -121,7 +122,7 @@ TEST_F(PtxAsmFormatTest, MultiLinePTX) {
|
||||
EXPECT_EQ(values[1], v[2]); // $1 -> v[2]
|
||||
}
|
||||
|
||||
TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) {
|
||||
TEST_F(PTXAsmFormatTest, onlyAttachMLIRArgs) {
|
||||
PTXBuilder builder;
|
||||
const char *ptxCode =
|
||||
".param .b64 param0;\n" // prepare param0 (format string)
|
||||
Reference in New Issue
Block a user