Merge branch 'triton-mlir-IFU' into merge_IFU_to_triton_mlir

This commit is contained in:
Rohit Santhanam
2023-01-03 23:37:11 +00:00
183 changed files with 15609 additions and 8077 deletions

2
.github/CODEOWNERS vendored
View File

@@ -28,6 +28,8 @@ lib/Analysis/Utility.cpp @Jokeren
# ----------
# Pipeline pass
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @daadaada
# Prefetch pass
lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @daadaada
# Coalesce pass
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @ptillet
# Layout simplification pass

View File

@@ -4,7 +4,7 @@ on:
workflow_dispatch:
pull_request:
branches:
- main
- master
- triton-mlir
jobs:
@@ -17,7 +17,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
@@ -40,26 +40,26 @@ jobs:
rm -rf ~/.triton/cache/
- name: Check imports
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check python style
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
@@ -81,9 +81,10 @@ jobs:
- name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted'}}
run: |
cd python/tests
cd python/test/unit/
pytest
- name: Run CXX unittests
run: |
cd python/

18
.gitignore vendored
View File

@@ -1,12 +1,24 @@
# Triton builds
build/
__pycache__
.pytest_cache
# Triton Python module builds
python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
# Python caches
__pycache__
.pytest_cache
# VS Code project files
.vscode
.vs
# JetBrains project files
.idea
cmake-build-*
# cache dumps
triton_cache*
log_*

View File

@@ -25,6 +25,10 @@ endif()
# used conditionally in this file and by lit tests
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
# Default build type
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Default build type: Release")
@@ -50,11 +54,7 @@ endif()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
if (TRITON_USE_ROCM)
set(MI_GPU_ARCH $ENV{MI_GPU_ARCH})
if (NOT MI_GPU_ARCH)
set(MI_GPU_ARCH "gfx90a")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_ROCM -DMI_GPU_ARCH=${MI_GPU_ARCH} -Wno-unused-result -Wno-attributes")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_ROCM -Wno-unused-result -Wno-attributes")
endif()
if(APPLE)
@@ -215,7 +215,7 @@ target_link_libraries(triton
TritonGPUTransforms
TritonLLVMIR
TritonPTX
TritonAMDGCN
TritonHSACO
${dialect_libs}
${conversion_libs}
# optimizations

View File

@@ -55,6 +55,15 @@ ctest
lit -v test
```
# Install from source
```
git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
```
# Changelog
Version 1.1 is out! New features include:

View File

@@ -10,8 +10,8 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
@@ -107,7 +107,8 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}
llvm::LLVMContext llvmContext;
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
auto llvmir =
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}

View File

@@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
Limitations
++++++++++++
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
.. table::
:widths: 50 50

View File

@@ -20,8 +20,6 @@ SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
} // namespace triton
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h

View File

@@ -131,6 +131,12 @@ public:
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
unsigned getPtrVectorSize(Value ptr);
unsigned getPtrAlignment(Value ptr);
unsigned getMaskAlignment(Value mask);
};
} // namespace mlir

View File

@@ -29,7 +29,11 @@ public:
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
private:
struct RegionInfo {
@@ -56,8 +60,12 @@ private:
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation);
/*WAR*/
isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation) ||
/*WAW*/
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
allocation);
}
/// Clears the buffers because a barrier is inserted.
@@ -78,10 +86,6 @@ private:
}
};
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:

View File

@@ -8,12 +8,57 @@
namespace mlir {
class ReduceOpHelper {
public:
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
srcTy = op.operand().getType().cast<RankedTensorType>();
}
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
Attribute getSrcLayout() { return srcTy.getEncoding(); }
bool isFastReduction();
unsigned getInterWarpSize();
unsigned getIntraWarpSize();
unsigned getThreadsReductionAxis();
SmallVector<unsigned> getScratchConfigBasic();
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
unsigned getScratchSizeInBytes();
private:
triton::ReduceOp op;
RankedTensorType srcTy{};
};
bool isSharedEncoding(Value value);
bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);
bool supportMMA(triton::DotOp op, int version);
bool supportMMA(Value value, int version);
Type getElementType(Value value);
std::string getValueOperandName(Value value, AsmState &state);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
SmallVector<T_OUT> out;
for (const T_IN &i : in)
out.push_back(T_OUT(i));
return out;
}
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}

View File

@@ -10,20 +10,22 @@ namespace triton {
namespace type {
// Integer types
Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
Type u32Ty(MLIRContext *ctx) {
// TODO(Superjomn): may change `static` into better implementations
static Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
static Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); }
static Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
static Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
}
Type u1Ty(MLIRContext *ctx) {
static Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
}
// Float types
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
static Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
static Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
static Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();

View File

@@ -2,8 +2,8 @@
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
namespace mlir {
namespace triton {

View File

@@ -44,6 +44,12 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::NVVM::NVVMDialect",
"mlir::ROCDL::ROCDLDialect",
"mlir::StandardOpsDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
#endif

View File

@@ -1,5 +1,5 @@
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_GCN_ASM_FORMAT_H_
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_GCN_ASM_FORMAT_H_
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_GCN_FORMAT_H_
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_GCN_FORMAT_H_
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -48,7 +48,7 @@ struct GCNBuilder {
std::string dump() const;
};
struct Modifier {
struct Modifier {
Value value;
std::string modifier;
std::string arg;
@@ -77,6 +77,7 @@ struct GCNBuilder {
}
return str;
}
std::string dump() const;
};
@@ -133,8 +134,6 @@ struct GCNBuilder {
Operand *newAddrOperand(mlir::Value addr, StringRef constraint);
Operand *newEmptyOperand(std::string arg);
Modifier *newModifier(StringRef modifier, StringRef arg);
llvm::SmallVector<Operand *, 4> getAllArgs() const;
@@ -192,12 +191,10 @@ struct GCNInstrCommon {
// clang-format on
// Set operands of this instruction.
GCNInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
llvm::ArrayRef<Modifier *> mods);
GCNInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs, llvm::ArrayRef<Modifier*> mods);
protected:
GCNInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
ArrayRef<Modifier *> mods);
GCNInstrExecution &call(llvm::ArrayRef<Operand *> oprs, ArrayRef<Modifier *> mods);
GCNBuilder *builder{};
llvm::SmallVector<std::string, 4> instrParts;
@@ -221,8 +218,35 @@ template <class ConcreteT> struct GCNInstrBase : public GCNInstrCommon {
}
};
enum VectorWidth {
Byte = 8,
Short = 16,
Dword = 32,
Qword = 64
};
struct GCNInstr : public GCNInstrBase<GCNInstr> {
using GCNInstrBase<GCNInstr>::GCNInstrBase;
GCNInstr &float_op_type(int width) {
switch (width) {
case Byte:
assert(Byte != width);
break;
case Short:
o("f16");
break;
case Dword:
o("f32");
break;
case Qword:
o("f64");
break;
default:
break;
}
return *this;
}
};
struct GCNInstrExecution {
@@ -234,10 +258,8 @@ struct GCNInstrExecution {
GCNInstrExecution() = default;
explicit GCNInstrExecution(GCNInstrCommon *instr,
llvm::ArrayRef<Operand *> oprs,
llvm::ArrayRef<Modifier *> modifiers)
: instr(instr), argsInOrder(oprs.begin(), oprs.end()),
mods(modifiers.begin(), modifiers.end()) {}
llvm::ArrayRef<Operand *> oprs, llvm::ArrayRef<Modifier *> modifiers)
: instr(instr), argsInOrder(oprs.begin(), oprs.end()), mods(modifiers.begin(), modifiers.end()) {}
std::string dump() const;
@@ -246,12 +268,12 @@ struct GCNInstrExecution {
GCNInstrCommon *instr{};
};
struct GCNMemInstr : public GCNInstrBase<GCNMemInstr> {
using GCNInstrBase<GCNMemInstr>::GCNInstrBase;
// Add specific type suffix to instruction
enum VectorWidth { Byte = 8, Short = 16, Dword = 32, Qword = 64 };
GCNMemInstr &load_type(int width) {
switch (width) {
case Byte:

View File

@@ -22,8 +22,8 @@ struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
// instructions.
//
// A helper for building a ASM program, the objective of PTXBuilder is to give a
// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// A helper for building an ASM program, the objective of PTXBuilder is to give
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// Currently, several factors are introduced to reduce the need for mixing
// string and C++ if-else code.
//
@@ -147,7 +147,7 @@ struct PTXBuilder {
Operand *newOperand(StringRef constraint);
// Create a constant integer operand.
Operand *newConstantOperand(int v);
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.
Operand *newConstantOperand(const std::string &v);
@@ -172,6 +172,22 @@ private:
return argArchive.back().get();
}
// Make the operands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
// determined by PTX code snippet passed from external.
sort(argArchive.begin(), argArchive.end(),
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
auto ida = std::find(order.begin(), order.end(), a.get());
auto idb = std::find(order.begin(), order.end(), b.get());
assert(ida != order.end());
assert(idb != order.end());
return ida < idb;
});
}
friend struct PTXInstr;
friend struct PTXInstrCommon;
@@ -201,10 +217,17 @@ struct PTXInstrCommon {
// clang-format on
// Set operands of this instruction.
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs);
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
protected:
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs);
// "Call" the instruction with operands.
// \param oprs The operands of this instruction.
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
// to the inline Asm without generating the operand ids(such as $0, $1) in PTX
// code.
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
PTXBuilder *builder{};
llvm::SmallVector<std::string, 4> instrParts;
@@ -234,70 +257,18 @@ template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
struct PTXInstr : public PTXInstrBase<PTXInstr> {
using PTXInstrBase<PTXInstr>::PTXInstrBase;
};
// A helper for PTX ld/st instruction.
// Usage:
// PtxIOInstr store("st");
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
// store.addAddr(addrValue, "l", off);
struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
// Append a ".global" to the instruction.
PTXInstr &global();
// Add ".global" suffix to instruction
PTXIOInstr &global(bool predicate = true) {
o("global", predicate);
return *this;
}
// Append a ".shared" to the instruction.
PTXInstr &shared();
// Add ".shared" suffix to instruction
PTXIOInstr &shared(bool predicate = true) {
o("shared", predicate);
return *this;
}
// Append a ".v[0-9]+" to the instruction
PTXInstr &v(int vecWidth, bool predicate = true);
// Add ".v" suffix to instruction
PTXIOInstr &v(int vecWidth, bool predicate = true) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
return *this;
}
// Add ".b" suffix to instruction
PTXIOInstr &b(int width) {
o("b" + std::to_string(width));
return *this;
}
};
struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
: PTXInstrBase(builder, "cp.async") {}
};
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("commit_group");
}
};
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("wait_group");
}
};
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier)
: PTXCpAsyncInstrBase(builder) {
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
}
// Append a".b[0-9]+" to the instruction
PTXInstr &b(int width);
};
// Record the operands and context for "launching" a PtxInstr.
@@ -308,8 +279,10 @@ struct PTXInstrExecution {
PTXInstrExecution() = default;
explicit PTXInstrExecution(PTXInstrCommon *instr,
llvm::ArrayRef<Operand *> oprs)
: argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs)
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
// Prefix a predicate to the instruction.
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
@@ -330,9 +303,24 @@ struct PTXInstrExecution {
PTXInstrCommon *instr{};
Operand *pred{};
bool onlyAttachMLIRArgs{};
};
/// ====== Some instruction wrappers ======
// We add the wrappers to make the usage more intuitive by avoiding mixing the
// PTX code with some trivial C++ code.
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier)
: PTXInstrBase(builder, "cp.async") {
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
}
};
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#endif

View File

@@ -1,42 +0,0 @@
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
mlir::LLVMTypeConverter &typeConverter);
};
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
};
namespace triton {
// Names for identifying different NVVM annotations. It is used as attribute
// names in MLIR modules. Refer to
// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for
// the full list.
struct NVVMMetadataField {
static constexpr char MaxNTid[] = "nvvm.maxntid";
static constexpr char Kernel[] = "nvvm.kernel";
};
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,22 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
} // namespace triton
} // namespace mlir
#endif

View File

@@ -1,5 +1,5 @@
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#include <memory>

View File

@@ -4,6 +4,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -30,7 +31,15 @@ public:
virtual LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
Attribute &resultEncoding,
Optional<Location> location) const = 0;
// Note: this function only verify operand encoding but doesn't infer result
// encoding
virtual LogicalResult
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
Attribute retEncoding,
Optional<Location> location) const = 0;
};
} // namespace triton

View File

@@ -59,7 +59,8 @@ def TT_AtomicRMWAttr : I32EnumAttr<
I32EnumAttrCase<"MAX", 6, "max">,
I32EnumAttrCase<"MIN", 7, "min">,
I32EnumAttrCase<"UMAX", 8, "umax">,
I32EnumAttrCase<"UMIN", 9, "umin">
I32EnumAttrCase<"UMIN", 9, "umin">,
I32EnumAttrCase<"XCHG", 10, "exch">
]> {
let cppNamespace = "::mlir::triton";
}

View File

@@ -3,4 +3,9 @@
include "mlir/IR/OpBase.td"
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
#endif // TRITON_INTERFACES

View File

@@ -10,10 +10,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
//
// Op Base
@@ -72,17 +69,16 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
// TODO: Add verifier
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";
let description = [{
Floating point casting for custom types (F8, BF8).
Floating point casting for custom types (F8).
F8 <-> BF8, FP16, FP32
BF8 <-> F8, FP16, FP32
F8 <-> FP16, BF16, FP32, FP64
}];
let arguments = (ins TT_FloatLike:$from);
@@ -103,15 +99,12 @@ def TT_AddPtrOp : TT_Op<"addptr",
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
"result", "ptr", "$_self">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
}
@@ -187,6 +180,7 @@ def TT_StoreOp : TT_Op<"store",
//
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"val", "ptr",
@@ -208,7 +202,9 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
let results = (outs TT_Type:$result);
}
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "atomic cas";
@@ -292,6 +288,18 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "transpose a tensor";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
//
// SPMD Ops
//
@@ -324,7 +332,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
@@ -348,6 +356,11 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
}];
}
//

View File

@@ -14,9 +14,8 @@ class TritonTypeDef<string name, string _mnemonic>
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

View File

@@ -25,7 +25,13 @@ namespace gpu {
unsigned getElemsPerThread(Type type);
SmallVector<unsigned> getSizePerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
SmallVector<unsigned> getContigPerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);

View File

@@ -71,6 +71,75 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
);
let builders = [
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"Type":$eltTy), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);
int opIdx = dotOpEnc.getOpIdx();
// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
// ---- begin Volta ----
if (mmaEnc.isVolta()) {
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
// TODO[Superjomn]: Support the case when is_vec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
is_vec4 = true;
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
}
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
// --- handle B operand ---
if (opIdx == 1) {
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
llvm_unreachable("invalid operand index");
}
// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
@@ -166,9 +235,11 @@ for
unsigned remainingLanes = 32;
unsigned remainingThreads = numWarps*32;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank; ++_dim) {
for (int _dim = 0; _dim < rank - 1; ++_dim) {
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
@@ -176,7 +247,12 @@ for
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= threadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
@@ -215,46 +291,50 @@ def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
let description = [{
An encoding for tensors that have been produced by tensor cores.
It is characterized by two parameters:
- A 'version' which specifies the generation the tensor cores
- A 'versionMajor' which specifies the generation the tensor cores
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
and 2 for second-gen tensor cores (Turing/Ampere).
- A 'versionMinor' which indicates the specific layout of a tensor core
generation, e.g. for Volta, there might be multiple kinds of layouts annotated
by 0,1,2 and so on.
- A `blockTileSize` to indicate how data should be
partitioned between warps.
// -------------------------------- version = 1 --------------------------- //
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
Information about this layout can be found in the official PTX documentation
Note: the layout is different from the recommended in PTX ISA
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).
For example, the matrix L corresponding to blockTileSize=[32,16] is:
For example, when versionMinor=1, the matrix L corresponding to
blockTileSize=[32,16] is:
warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14]
[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15]
[ ..............................................................
[ ..............................................................
[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
warp 1 = warp0 + 32
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
[ ..............................................................
[ ..............................................................
[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62]
[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63]
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
[ ............................................................... ]
// -------------------------------- version = 2 --------------------------- //
@@ -290,11 +370,39 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
let parameters = (
ins
"unsigned":$version,
"unsigned":$versionMajor,
"unsigned":$versionMinor,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);
let extraClassDeclaration = extraBaseClassDeclaration;
let builders = [
// specific for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"ArrayRef<unsigned>":$warpsPerCTA,
"ArrayRef<int64_t>":$shapeA,
"ArrayRef<int64_t>":$shapeB,
"bool":$isARow,
"bool":$isBRow), [{
assert(versionMajor == 1 && "Only MMAv1 has multiple versionMinor.");
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
int versionMinor = (isARow * (1<<0)) |\
(isBRow * (1<<1)) |\
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
return $_get(context, versionMajor, versionMinor, warpsPerCTA);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
bool isVolta() const;
bool isAmpere() const;
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
std::tuple<bool, bool, bool, bool> decodeVoltaLayoutStates() const;
}];
}
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
@@ -326,11 +434,11 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
);
let extraClassDeclaration = extraBaseClassDeclaration # [{
SmallVector<int64_t> paddedShape(ArrayRef<int64_t> shape) const;
template<class T>
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
}];
}
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
let mnemonic = "dot_op";
@@ -339,15 +447,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parend field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif

View File

@@ -32,13 +32,21 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 80;
}
}];
}
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";
let description = [{}];
@@ -50,7 +58,9 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let results = (outs TT_BoolLike:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";
let description = [{}];
@@ -63,7 +73,9 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
}
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";
let description = [{}];
@@ -79,8 +91,7 @@ def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
NoSideEffect,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
@@ -152,13 +163,24 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
// attr-dict `:` type($src) `->` type($dst)
//}];
let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
}];
// The custom parser could be replaced with oilist in LLVM-16
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
}
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect, ResultsAreSharedEncoding]> {
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
ResultsAreSharedEncoding]> {
let summary = "allocate tensor";
let description = [{

View File

@@ -6,13 +6,14 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
// TODO(Keren): prefetch pass not working yet
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
std::unique_ptr<Pass> createTritonGPUSwizzlePass();
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
std::unique_ptr<Pass> createTritonGPUCoalescePass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUVerifier();

View File

@@ -7,7 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
TODO
Unroll loops to hide global memory -> shared memory latency.
}];
let constructor = "mlir::createTritonGPUPipelinePass()";
@@ -23,11 +23,25 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";
let description = [{
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
}];
let constructor = "mlir::createTritonGPUPrefetchPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithmeticDialect"];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";
let description = [{
TODO
TODO
}];
let constructor = "mlir::createTritonGPUCoalescePass()";
@@ -49,18 +63,12 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUSwizzle : Pass<"tritongpu-swizzle", "mlir::ModuleOp"> {
let summary = "swizzle";
let description = [{
Inserts conversions to swizzled layout so as to avoid shared memory bank conflicts.
}];
let constructor = "mlir::createTritonGPUSwizzlePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {

View File

@@ -14,6 +14,7 @@ namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
int getNumWarps() const { return numWarps; }
private:
MLIRContext *context;

View File

@@ -0,0 +1,20 @@
#ifndef TRITON_TARGET_HSACOTRANSLATION_H
#define TRITON_TARGET_HSACOTRANSLATION_H
#include <memory>
#include <string>
#include <tuple>
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
// Translate TritonGPU IR to HSACO code.
std::tuple<std::string, std::string> translateLLVMIRToHSACO(llvm::Module& module,
std::string cc);
} // namespace triton
#endif

View File

@@ -2,6 +2,7 @@
#define TRITON_TARGET_LLVMIRTRANSLATION_H
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
#include <vector>
namespace llvm {
@@ -24,7 +25,7 @@ void addExternalLibs(mlir::ModuleOp &module,
// Translate TritonGPU dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module);
mlir::ModuleOp module, int computeCapability);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>

View File

@@ -1,7 +1,6 @@
#ifndef TRITON_TARGET_PTXTRANSLATION_H
#define TRITON_TARGET_PTXTRANSLATION_H
#include <memory>
#include <string>
namespace llvm {

View File

@@ -26,13 +26,14 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (auto insertSliceOp =
dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
} else if (isa<tensor::InsertSliceOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
// insert_slice_async %src, %dst, %index
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());
pessimistic = false;
} else if (isSharedEncoding(result)) {
@@ -44,7 +45,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
if (pessimistic) {
return markAllPessimisticFixpoint(op->getResults());
}
// Join all latice elements
// Join all lattice elements
ChangeResult result = ChangeResult::NoChange;
for (Value value : op->getResults()) {
result |= getLatticeElement(value).join(aliasInfo);

View File

@@ -12,6 +12,8 @@
#include <numeric>
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
@@ -26,6 +28,29 @@ namespace mlir {
//===----------------------------------------------------------------------===//
namespace triton {
// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");
// mma or dot layout does not have an order, so the order depends on the
// layout of the other operand.
auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
: getOrder(srcLayout);
auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
: getOrder(dstLayout);
return {inOrd, outOrd};
}
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
@@ -35,18 +60,9 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
Attribute dstLayout = dstTy.getEncoding();
assert(srcLayout && dstLayout &&
"Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");
auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -55,6 +71,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] =
@@ -71,30 +89,24 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
return paddedRepShape;
}
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcShape = srcTy.getShape();
auto axis = op.axis();
bool fastReduce = axis == srcLayout.getOrder()[0];
// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;
for (auto d : srcShape)
smemShape.push_back(d);
if (fastReduce) {
unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
smemShape[axis] = sizeInterWarps;
if (op.ptr().getType().isa<RankedTensorType>()) {
// do nothing or just assert because shared memory is not used in tensor up
// to now
} else {
unsigned threadsPerCTAAxis =
srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis];
smemShape[axis] = threadsPerCTAAxis;
// need only bytes for scalar
// always vec = 1 and elemsPerThread = 1 for scalar?
smemShape.push_back(1);
}
return smemShape;
}
SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
return SmallVector<unsigned>{1};
}
class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -124,8 +136,7 @@ private:
// For example: %a = scf.if -> yield
// %a must be allocated elsewhere by other operations.
// FIXME(Keren): extract and insert are always alias for now
if (!maybeSharedAllocationOp(op) || isa<tensor::ExtractSliceOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
return;
}
@@ -143,23 +154,10 @@ private:
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
// TODO(Keren): Add atomic ops
// TODO(Keren): Add convert ops
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
// TODO(Keren): Reduce with index is not supported yet.
auto value = op->getOperand(0);
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<BlockedEncodingAttr>()) {
auto smemShape = getScratchConfigForReduce(reduceOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(),
1, std::multiplies{});
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else {
assert(0 && "ReduceOp with input layout other than blocked layout is "
"not implemented yet");
}
}
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
@@ -167,7 +165,7 @@ private:
auto dstEncoding = dstTy.getEncoding();
if (srcEncoding.isa<SharedEncodingAttr>() ||
dstEncoding.isa<SharedEncodingAttr>()) {
// Only blocked -> blocked conversion requires for scratch allocation
// Conversions from/to shared memory do not need scratch memory.
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
@@ -179,7 +177,39 @@ private:
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
auto bytes =
srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (value.getType().dyn_cast<RankedTensorType>()) {
// nothing to do
} else {
auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes =
elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
auto value = op->getOperand(0);
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}
@@ -228,7 +258,7 @@ private:
}
}
/// Extends the liveness range by unioning the liveness range of the aliased
/// Extends the liveness range by unionizing the liveness range of the aliased
/// values because each allocated buffer could be an alias of others, if block
/// arguments are involved.
void resolveAliasBufferLiveness(

View File

@@ -132,6 +132,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
}
}
// TODO: refactor & complete binary ops
// Addition
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
@@ -159,6 +160,20 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Remainder
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
@@ -261,4 +276,46 @@ ChangeResult AxisInfoAnalysis::visitOperation(
return result;
}
unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);
unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);
return vec;
}
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto axisInfo = lookupLatticeElement(ptr)->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
unsigned maxContig = axisInfo.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}
unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto maskAxis = lookupLatticeElement(mask)->getValue();
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
return alignment;
}
} // namespace mlir

View File

@@ -1,4 +1,5 @@
#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
@@ -23,21 +24,43 @@ void MembarAnalysis::dfsOperation(Operation *operation,
// scf.if only: two regions
// scf.for: one region
RegionInfo curRegionInfo;
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
assert(region.getBlocks().size() == 1 &&
"Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
auto traverseRegions = [&]() -> auto{
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
assert(region.getBlocks().size() == 1 &&
"Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
}
}
curRegionInfo.join(regionInfo);
}
curRegionInfo.join(regionInfo);
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
};
traverseRegions();
if (isa<scf::ForOp>(operation)) {
// scf.for can have two possible inputs: the init value and the
// previous iteration's result. Although we've applied alias analysis,
// there could be unsynced memory accesses on reused memories.
// For example, consider the following code:
// %1 = convert_layout %0: blocked -> shared
// ...
// gpu.barrier
// ...
// %5 = convert_layout %4 : shared -> dot
// %6 = tt.dot %2, %5
// scf.yield
//
// Though %5 could be released before scf.yield, it may shared the same
// memory with %1. So we actually have to insert a barrier before %1 to
// make sure the memory is synced.
traverseRegions();
}
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
}
}
@@ -48,8 +71,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
// Do not insert barriers before control flow operations and
// alloc/extract/insert
// alloc is an allocation op without memory write.
// In contrast, arith.constant is an allocation op with memory write.
// FIXME(Keren): extract is always alias for now
// FIXME(Keren): extract_slice is always alias for now
return;
}
@@ -59,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
return;
}
if (isa<triton::gpu::AsyncWaitOp>(op)) {
// If the current op is an async wait, we insert a barrier op and sync
// previous reads and writes.
if (isa<triton::gpu::AsyncWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
regionInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());
@@ -71,11 +95,17 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
RegionInfo curRegionInfo;
for (Value value : op->getOperands()) {
// ConvertLayoutOp: shared memory -> registers
// Need to consider all alias buffers
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncReadBuffers.insert(bufferId);
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always alias
// for now
curRegionInfo.syncWriteBuffers.insert(bufferId);
} else {
// ConvertLayoutOp: shared memory -> registers
curRegionInfo.syncReadBuffers.insert(bufferId);
}
}
}
}
@@ -86,9 +116,10 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
curRegionInfo.syncWriteBuffers.insert(bufferId);
}
}
// Scratch buffer is considered as a shared memory read
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
curRegionInfo.syncReadBuffers.insert(bufferId);
}

View File

@@ -5,6 +5,82 @@
namespace mlir {
bool ReduceOpHelper::isFastReduction() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.axis();
return axis == triton::gpu::getOrder(srcLayout)[0];
}
unsigned ReduceOpHelper::getInterWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.axis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
}
unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.axis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
}
unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.axis();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.axis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.axis();
SmallVector<SmallVector<unsigned>> smemShapes(3);
/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
smemShapes[0][axis] = getInterWarpSize();
/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);
return smemShapes;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned elems = 0;
if (isFastReduction()) {
auto smemShapes = getScratchConfigsFast();
for (const auto &smemShape : smemShapes)
elems = std::max(elems, product<unsigned>(smemShape));
} else {
auto smemShape = getScratchConfigBasic();
elems = product<unsigned>(smemShape);
}
auto tensorType = op.operand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
if (triton::ReduceOp::withIndex(op.redOp()))
bytes += elems * sizeof(int32_t);
return bytes;
}
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
@@ -28,6 +104,43 @@ bool maybeSharedAllocationOp(Operation *op) {
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
}
bool maybeAliasOp(Operation *op) {
return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}
bool supportMMA(triton::DotOp op, int version) {
// Refer to mma section for the data type supported by Volta and Hopper
// Tensor Core in
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.allowTF32() && version >= 2;
}
return supportMMA(op.a(), version) && supportMMA(op.b(), version);
}
bool supportMMA(Value value, int version) {
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
// We cannot get both the operand types(in TypeConverter), here we assume the
// types of both the operands are identical here.
assert((version == 1 || version == 2) &&
"Unexpected MMA layout version found");
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
}
Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}
std::string getValueOperandName(Value value, AsmState &state) {
std::string opName;
llvm::raw_string_ostream ss(opName);

View File

@@ -1,21 +0,0 @@
#ifndef TRITON_CONVERSION_PASSDETAIL_H
#define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
namespace triton {
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@@ -1,7 +1,14 @@
add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUToLLVM.cpp
GcnAsmFormat.cpp
PtxAsmFormat.cpp
GCNAsmFormat.cpp
PTXAsmFormat.cpp
TritonGPUToLLVMPass.cpp
ConvertLayoutOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
ViewOpToLLVM.cpp
LoadStoreOpToLLVM.cpp
DotOpToLLVM.cpp
ReduceOpToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM

View File

@@ -0,0 +1,686 @@
#include "ConvertLayoutOpToLLVM.h"
#include "DotOpHelpers.h"
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
DotOperandEncodingAttr &dotOperandLayout) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
return mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout;
}
void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
ArrayRef<Value> srcIndices, Value dst, Value smemBase,
Type elemTy, Location loc,
ConversionPatternRewriter &rewriter) {
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstSharedLayout.getOrder();
if (inOrd != outOrd)
llvm_unreachable(
"blocked -> shared with different order not yet implemented");
unsigned inVec =
inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = getElemsPerThread(srcTy);
auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
auto srcAccumSizeInThreads =
product<unsigned>(srcBlockedLayout.getSizePerThread());
auto wordTy = vec_ty(elemTy, minVec);
auto elemPtrTy = ptr_ty(elemTy);
// TODO: [goostavz] We should make a cache for the calculation of
// emitBaseIndexForBlockedLayout in case backend compiler not being able to
// optimize that
SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
// Visit each input value in the order they are placed in inVals
//
// Please note that the order was not awaring of blockLayout.getOrder(),
// thus the adjacent elems may not belong to a same word. This could be
// improved if we update the elements order by emitIndicesForBlockedLayout()
SmallVector<unsigned> wordsInEachRep(2);
wordsInEachRep[0] = inOrd[0] == 0
? srcBlockedLayout.getSizePerThread()[0] / minVec
: srcBlockedLayout.getSizePerThread()[0];
wordsInEachRep[1] = inOrd[0] == 0
? srcBlockedLayout.getSizePerThread()[1]
: srcBlockedLayout.getSizePerThread()[1] / minVec;
Value outVecVal = i32_val(outVec);
Value minVecVal = i32_val(minVec);
auto numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
for (unsigned i = 0; i < numElems; ++i) {
if (i % srcAccumSizeInThreads == 0) {
// start of a replication
for (unsigned w = 0; w < numWordsEachRep; ++w) {
wordVecs[w] = undef(wordTy);
}
}
unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
auto wordVecIdx =
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
wordVecs[wordVecIdx] =
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
// end of replication, store the vectors into shared memory
unsigned linearRepIdx = i / srcAccumSizeInThreads;
auto multiDimRepIdx =
getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
++linearWordIdx) {
// step 1: recover the multidim_index from the index of
// input_elements
auto multiDimWordIdx =
getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
SmallVector<Value> multiDimIdx(2);
auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
// step 2: do swizzling
Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
phaseId = urem(phaseId, i32_val(maxPhase));
Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
off_0 = mul(off_0, outVecVal);
remained = udiv(remained, minVecVal);
off_0 = add(off_0, mul(remained, minVecVal));
Value offset = add(off_1, off_0);
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
store(wordVecs[linearWordIdx], smemAddr);
}
}
}
}
struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<SharedEncodingAttr>()) {
return lowerBlockedToShared(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, rewriter);
}
if ((srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>()) &&
(dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>())) {
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
if (srcLayout.isa<MmaEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
}
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, ArrayRef<int64_t> shape,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
return multiDimOffset;
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto multiDimOffsetParent =
getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
sliceLayout.paddedShape(shape),
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
continue;
unsigned slicedD = d < dim ? d : (d - 1);
multiDimOffset[slicedD] = multiDimOffsetParent[d];
}
return multiDimOffset;
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
SmallVector<Value> mmaColIdx(4);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
Value _1 = idx_val(1);
Value _2 = idx_val(2);
Value _4 = idx_val(4);
Value _8 = idx_val(8);
Value _16 = idx_val(16);
if (mmaLayout.isAmpere()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
Value mmaGrpId = udiv(laneId, _4);
Value mmaGrpIdP8 = add(mmaGrpId, _8);
Value mmaThreadIdInGrp = urem(laneId, _4);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
Value colWarpOffset = mul(multiDimWarpId[1], _8);
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
} else if (mmaLayout.isVolta()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
Value laneIdDiv16 = udiv(laneId, _16);
Value laneIdRem16 = urem(laneId, _16);
Value laneIdRem2 = urem(laneId, _2);
Value laneIdRem16Div8 = udiv(laneIdRem16, _8);
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
Value colWarpOffset = mul(multiDimWarpId[1], _16);
mmaRowIdx[0] =
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
laneIdRem2);
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _8);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
if (mmaLayout.isAmpere()) {
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.isVolta()) {
// the order of elements in a thread:
// c0, c1, ... c4, c5
// c2, c3, ... c6, c7
if (elemId < 2) {
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 2 && elemId < 4) {
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 4 && elemId < 6) {
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
} else if (elemId >= 6) {
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
}
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
return multiDimOffset;
}
llvm_unreachable("unexpected layout in getMultiDimOffset");
}
// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
Value smemBase) const {
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
auto order = getOrder(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
bool isPtr = elemTy.isa<triton::PointerType>();
auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
if (isInt1)
elemTy = IntegerType::get(elemTy.getContext(), 8);
else if (isPtr)
elemTy = IntegerType::get(elemTy.getContext(), 64);
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
SmallVector<unsigned> multiDimCTAId(rank);
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
auto linearCTAId =
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
// TODO: This is actually redundant index calculation, we should
// consider of caching the index calculation result in case
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
if (isInt1)
currVal = zext(llvmElemTy, currVal);
else if (isPtr)
currVal = ptrtoint(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
if (isInt1)
currVal = icmp_ne(currVal,
rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
else if (isPtr)
currVal = inttoptr(llvmElemTyOrig, currVal);
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
}
}
}
}
}
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getElemsPerThread(srcTy);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstTy);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId =
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
if (repId != 0)
barrier();
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
barrier();
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}
SmallVector<Type> types(outElems, llvmElemTy);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
}
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
LogicalResult
lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstSharedLayout.getOrder();
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto srcStrides =
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter,
srcBlockedLayout, srcShape);
storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
auto smemObj =
SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
// shared -> mma_operand
LogicalResult
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
auto srcTensorTy = src.getType().cast<RankedTensorType>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
bool isOuter{};
int K{};
if (dotOperandLayout.getOpIdx() == 0) // $a
K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]];
else // $b
K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]];
isOuter = K == 1;
Value res;
if (auto mmaLayout =
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>()) {
res = lowerSharedToDotOperandMMA(op, adaptor, rewriter, mmaLayout,
dotOperandLayout, isOuter);
} else if (auto blockedLayout =
dotOperandLayout.getParent()
.dyn_cast_or_null<BlockedEncodingAttr>()) {
auto dotOpLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
DotOpFMAConversionHelper helper(blockedLayout);
auto thread = getThreadId(rewriter, loc);
if (dotOpLayout.getOpIdx() == 0) { // $a
res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
rewriter);
} else { // $b
res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
rewriter);
}
} else {
assert(false && "Unsupported dot operand layout found");
}
rewriter.replaceOp(op, res);
return success();
}
// mma -> dot_operand
LogicalResult
lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto dstTy = op.result().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
// get source values
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned elems = getElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
unsigned vecSize =
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
Type vecTy = vec_ty(elemTy, vecSize);
SmallVector<Type> types(elems / vecSize, vecTy);
SmallVector<Value> vecVals;
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
// This needs to be ordered the same way that
// ldmatrix.x4 would order it
// TODO: this needs to be refactor so we don't
// implicitly depends on how emitOffsetsForMMAV2
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(vecVals[i]);
reorderedVals.push_back(vecVals[i + 2]);
reorderedVals.push_back(vecVals[i + 1]);
reorderedVals.push_back(vecVals[i + 3]);
}
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value view =
getStructFromElements(loc, reorderedVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
return failure();
}
// shared -> dot_operand if the result layout is mma
Value lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
bool isHMMA = supportMMA(dst, mmaLayout.getVersionMajor());
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
Value res;
if (!isOuter && mmaLayout.isAmpere() && isHMMA) { // tensor core v2
MMA16816ConversionHelper mmaHelper(src.getType(), mmaLayout,
getThreadId(rewriter, loc), rewriter,
getTypeConverter(), op.getLoc());
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = mmaHelper.loadA(src, smemObj);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, smemObj);
}
} else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc),
loc, rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
// TODO[Superjomn]: transB is not available here.
bool transB = false;
res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc),
loc, rewriter);
}
} else {
assert(false && "Unsupported mma layout found");
}
return res;
}
};
void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
}

View File

@@ -0,0 +1,24 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_CONVERT_LAYOUT_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_CONVERT_LAYOUT_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
DotOperandEncodingAttr &dotOperandLayout);
void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
ArrayRef<Value> srcIndices, Value dst, Value smemBase,
Type elemPtrTy, Location loc,
ConversionPatternRewriter &rewriter);
void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit);
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,311 @@
#include "DotOpToLLVM.h"
#include "DotOpHelpers.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// D = A * B + C
Value A = op.a();
Value D = op.getResult();
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
bool isOuter = K == 1;
MmaEncodingAttr mmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) {
if (mmaLayout.isVolta())
return convertMMA884(op, adaptor, rewriter);
if (mmaLayout.isAmpere())
return convertMMA16816(op, adaptor, rewriter);
llvm::report_fatal_error(
"Unsupported MMA kind found when converting DotOp to LLVM.");
}
if (D.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<BlockedEncodingAttr>())
return convertFMADot(op, adaptor, rewriter);
llvm::report_fatal_error(
"Unsupported DotOp found when converting TritonGPU to LLVM.");
}
private:
// Convert to mma.m16n8k16
LogicalResult convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto mmaLayout = op.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
Value A = op.a();
Value B = op.b();
Value C = op.c();
MMA16816ConversionHelper mmaHelper(A.getType(), mmaLayout,
getThreadId(rewriter, loc), rewriter,
getTypeConverter(), loc);
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
assert(ATensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
"Both $a and %b should be DotOperand layout.");
Value loadedA, loadedB, loadedC;
loadedA = adaptor.a();
loadedB = adaptor.b();
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
adaptor);
}
/// Convert to mma.m8n8k4
LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *ctx = op.getContext();
auto loc = op.getLoc();
Value A = op.a();
Value B = op.b();
Value D = op.getResult();
auto mmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ALayout = A.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto BLayout = B.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
auto DTensorTy = D.getType().cast<RankedTensorType>();
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = helper.getNumM(AShape, isARow);
unsigned numN = helper.getNumN(BShape, isBRow);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);
// Initialize accumulators with external values, the acc holds the
// accumulator value that is shared between the MMA instructions inside a
// DotOp, we can call the order of the values the accumulator-internal
// order.
SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
size_t resSize = acc.size();
// The resVals holds the final result of the DotOp.
// NOTE The current order of resVals is different from acc, we call it the
// accumulator-external order. and
SmallVector<Value> resVals(resSize);
auto getIdx = [&](int m, int n) {
std::vector<size_t> idx{{
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
(m * 2 + 0) + (n * 4 + 1) * numM,
(m * 2 + 1) + (n * 4 + 0) * numM, // row1
(m * 2 + 1) + (n * 4 + 1) * numM,
(m * 2 + 0) + (n * 4 + 2) * numM, // row2
(m * 2 + 0) + (n * 4 + 3) * numM,
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
(m * 2 + 1) + (n * 4 + 3) * numM,
}};
return idx;
};
{ // convert the acc's value from accumuator-external order to
// accumulator-internal order.
SmallVector<Value> accInit(acc.size());
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
auto idx = getIdx(m, n);
for (unsigned i = 0; i < 8; ++i)
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
}
acc = accInit;
}
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
PTXBuilder builder;
auto idx = getIdx(m, n);
auto *resOprs = builder.newListOperand(8, "=f");
auto *AOprs = builder.newListOperand({
{ha.first, "r"},
{ha.second, "r"},
});
auto *BOprs = builder.newListOperand({
{hb.first, "r"},
{hb.second, "r"},
});
auto *COprs = builder.newListOperand();
for (int i = 0; i < 8; ++i)
COprs->listAppend(builder.newOperand(acc[idx[i]], std::to_string(i)));
auto mma = builder.create("mma.sync.aligned.m8n8k4")
->o(isARow ? "row" : "col")
.o(isBRow ? "row" : "col")
.o("f32.f16.f16.f32");
mma(resOprs, AOprs, BOprs, COprs);
Value res =
builder.launch(rewriter, loc, helper.getMmaRetType(ATensorTy));
auto getIntAttr = [&](int v) {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
for (unsigned i = 0; i < 8; i++) {
Value elem = extract_val(f32_ty, res, getIntAttr(i));
acc[idx[i]] = elem;
resVals[(m * numN / 2 + n) * 8 + i] = elem;
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
callMMA(m, n, k);
}
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
Value res = getStructFromElements(loc, resVals, rewriter, structTy);
rewriter.replaceOp(op, res);
return success();
}
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *ctx = rewriter.getContext();
auto loc = op.getLoc();
auto threadId = getThreadId(rewriter, loc);
auto A = op.a();
auto B = op.b();
auto C = op.c();
auto D = op.getResult();
auto aTensorTy = A.getType().cast<RankedTensorType>();
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto cTensorTy = C.getType().cast<RankedTensorType>();
auto dTensorTy = D.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto bShape = bTensorTy.getShape();
auto cShape = cTensorTy.getShape();
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto order = dLayout.getOrder();
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
DotOpFMAConversionHelper helper(dLayout);
Value llA = adaptor.a();
Value llB = adaptor.b();
auto sizePerThread = getSizePerThread(dLayout);
auto shapePerCTA = getShapePerCTA(dLayout);
int K = aShape[1];
int M = aShape[0];
int N = bShape[1];
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int mSizePerThread =
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nSizePerThread =
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
mSizePerThread, rewriter, loc);
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
nSizePerThread, rewriter, loc);
SmallVector<Value> ret = cc;
bool isCRow = order[0] == 1;
for (unsigned k = 0; k < K; k++) {
for (unsigned m = 0; m < M; m += mShapePerCTA)
for (unsigned n = 0; n < N; n += nShapePerCTA)
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
int mIdx = m / mShapePerCTA * mSizePerThread + mm;
int nIdx = n / nShapePerCTA * nSizePerThread + nn;
int z = isCRow ? mIdx * N / nShapePerCTA * mSizePerThread + nIdx
: nIdx * M / mShapePerCTA * nSizePerThread + mIdx;
ret[z] = rewriter.create<LLVM::FMulAddOp>(
loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]);
}
}
auto res = getStructFromElements(
loc, ret, rewriter,
struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
rewriter.replaceOp(op, res);
return success();
}
};
void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
}

View File

@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,908 @@
#include "ElementwiseOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
static SmallVector<Value>
convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
auto fp16x2x2Struct =
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
auto fp16x2Vec0 =
extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0}));
auto fp16x2Vec1 =
extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1}));
return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
}
static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
static SmallVector<Value>
convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto bf16x2VecTy = vec_ty(i16_ty, 2);
auto bf16x2x2StructTy =
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
auto bf16x2x2Struct =
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
auto bf16x2Vec0 =
extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0}));
auto bf16x2Vec1 =
extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1}));
return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
extract_element(i16_ty, bf16x2Vec1, i32_val(1))};
}
static SmallVector<Value>
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
"mov.u32 fp8_min, 0x38003800; \n"
"mov.u32 fp8_max, 0x3ff03ff0; \n"
"mov.u32 rn_, 0x80008; \n"
"mov.u32 zero, 0; \n"
"and.b32 sign0, $1, 0x80008000; \n"
"and.b32 sign1, $2, 0x80008000; \n"
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n"
"and.b32 nosign1, $2, 0x7fff7fff; \n"
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n"
"add.u32 nosign1, nosign1, rn_; \n"
"sub.u32 nosign0, nosign0, 0x38003800; \n"
"sub.u32 nosign1, nosign1, 0x38003800; \n"
"shr.u32 nosign0, nosign0, 4; \n"
"shr.u32 nosign1, nosign1, 4; \n"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
"or.b32 $0, nosign, sign; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
static SmallVector<Value>
convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
}
static SmallVector<Value>
convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
static SmallVector<Value>
convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])};
}
static SmallVector<Value>
convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
static Value convertBf16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
auto as_int16 = bitcast(v, i16_ty);
auto as_int32 = zext(i32_ty, as_int16);
auto shifted = shl(i32_ty, as_int32, i32_val(16));
return(bitcast(shifted, f32_ty));
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f32.bf16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(v, "h");
cvt(res, operand);
return builder.launch(rewriter, loc, f32_ty, false);
#endif
}
static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
auto as_int32 = bitcast(v, i32_ty);
auto shifted = lshr(i32_ty, as_int32, i32_val(16));
auto truncated = trunc(i16_ty, shifted);
return(bitcast(truncated, i16_ty));
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.bf16.f32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(v, "r");
cvt(res, operand);
// TODO: This is a hack to get the right type. We should be able to invoke
// the type converter
return builder.launch(rewriter, loc, i16_ty, false);
#endif
}
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
auto srcEltType = srcTensorType.getElementType();
auto dstEltType = dstTensorType.getElementType();
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
SmallVector<Value> resultVals;
// Select convertor
if (srcEltType.isa<triton::Float8Type>() ||
dstEltType.isa<triton::Float8Type>()) {
std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const Value &, const Value &,
const Value &, const Value &)>
convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
convertor = convertFp8x4ToBf16x4;
} else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertBf16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
convertor = convertFp8x4ToFp32x4;
} else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp32x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
convertor = convertFp8x4ToFp64x4;
} else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp64x4ToFp8x4;
} else {
assert(false && "unsupported fp8 casting");
}
// Vectorized casting
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
} else if (srcEltType.isBF16() && dstEltType.isF32()) {
resultVals.emplace_back(convertBf16ToFp32(loc, rewriter, adaptor.from()));
} else if (srcEltType.isF32() && dstEltType.isBF16()) {
resultVals.emplace_back(convertFp32ToBf16(loc, rewriter, adaptor.from()));
} else {
assert(false && "unsupported type casting");
}
assert(resultVals.size() == elems);
auto convertedDstTensorType =
this->getTypeConverter()->convertType(dstTensorType);
auto result = getStructFromElements(loc, resultVals, rewriter,
convertedDstTensorType);
rewriter.replaceOp(op, result);
return success();
}
};
template <typename OP>
Value EmitDualBF16ElementwiseOp(Location loc,
ConversionPatternRewriter &rewriter,
ValueRange operands) {
auto v0 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
auto v1 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[1]);
auto result = rewriter.create<OP>(loc, f32_ty, v0, v1);
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, result);
}
template <typename SourceOp, typename ConcreteT>
class ElementwiseOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType();
Location loc = op->getLoc();
unsigned elems = getElemsPerThread(resultTy);
auto resultElementTy = getElementTypeOrSelf(resultTy);
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
SmallVector<Type> types(elems, elemTy);
Type structTy = this->getTypeConverter()->convertType(resultTy);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto operands = getOperands(rewriter, adaptor, elems, loc);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
operands[i], loc);
if (!bool(resultVals[i]))
return failure();
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
protected:
SmallVector<SmallVector<Value>>
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
const unsigned elems, Location loc) const {
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
for (size_t i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
}
return operands;
}
};
template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion
: public ElementwiseOpConversionBase<
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
using Base =
ElementwiseOpConversionBase<SourceOp,
ElementwiseOpConversion<SourceOp, DestOp>>;
using Base::Base;
using OpAdaptor = typename Base::OpAdaptor;
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
typeConverter, benefit) {}
// An interface to support variant DestOp builder.
DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, operands,
adaptor.getAttributes().getValue());
}
};
struct CmpIOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
CmpIOpConversion> {
using Base =
ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
// An interface to support variant DestOp builder.
LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return rewriter.create<LLVM::ICmpOp>(
loc, elemTy, ArithCmpIPredicateToLLVM(op.predicate()), operands[0],
operands[1]);
}
static LLVM::ICmpPredicate
ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) {
switch (predicate) {
#define __PRED_ENUM(item__) \
case arith::CmpIPredicate::item__: \
return LLVM::ICmpPredicate::item__
__PRED_ENUM(eq);
__PRED_ENUM(ne);
__PRED_ENUM(sgt);
__PRED_ENUM(sge);
__PRED_ENUM(slt);
__PRED_ENUM(sle);
__PRED_ENUM(ugt);
__PRED_ENUM(uge);
__PRED_ENUM(ult);
__PRED_ENUM(ule);
#undef __PRED_ENUM
}
return LLVM::ICmpPredicate::eq;
}
};
struct CmpFOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
CmpFOpConversion> {
using Base =
ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
// An interface to support variant DestOp builder.
static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, ValueRange operands,
Location loc) {
return rewriter.create<LLVM::FCmpOp>(
loc, elemTy, ArithCmpFPredicateToLLVM(op.predicate()), operands[0],
operands[1]);
}
static LLVM::FCmpPredicate
ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) {
switch (predicate) {
#define __PRED_ENUM(item__, item1__) \
case arith::CmpFPredicate::item__: \
return LLVM::FCmpPredicate::item1__
__PRED_ENUM(OEQ, oeq);
__PRED_ENUM(ONE, one);
__PRED_ENUM(OGT, ogt);
__PRED_ENUM(OGE, oge);
__PRED_ENUM(OLT, olt);
__PRED_ENUM(OLE, ole);
__PRED_ENUM(ORD, ord);
__PRED_ENUM(UEQ, ueq);
__PRED_ENUM(UGT, ugt);
__PRED_ENUM(UGE, uge);
__PRED_ENUM(ULT, ult);
__PRED_ENUM(ULE, ule);
__PRED_ENUM(UNE, une);
__PRED_ENUM(UNO, uno);
__PRED_ENUM(AlwaysTrue, _true);
__PRED_ENUM(AlwaysFalse, _false);
#undef __PRED_ENUM
}
return LLVM::FCmpPredicate::_true;
}
};
struct ExtElemwiseOpConversion
: public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
ExtElemwiseOpConversion> {
using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
ExtElemwiseOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
StringRef funcName = op.symbol();
if (funcName.empty())
llvm::errs() << "ExtElemwiseOpConversion";
Type funcType = getFunctionType(elemTy, operands);
LLVM::LLVMFuncOp funcOp =
appendOrGetFuncOp(rewriter, op, funcName, funcType);
return rewriter.create<LLVM::CallOp>(loc, funcOp, operands).getResult(0);
}
private:
Type getFunctionType(Type resultType, ValueRange operands) const {
SmallVector<Type> operandTypes(operands.getTypes());
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
triton::ExtElemwiseOp op,
StringRef funcName, Type funcType) const {
using LLVM::LLVMFuncOp;
auto funcAttr = StringAttr::get(op->getContext(), funcName);
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
if (funcOp)
return cast<LLVMFuncOp>(*funcOp);
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
ret.getOperation()->setAttr(
"libname", StringAttr::get(op->getContext(), op.libname()));
ret.getOperation()->setAttr(
"libpath", StringAttr::get(op->getContext(), op.libpath()));
return ret;
}
};
struct FDivOpConversion
: ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
#ifdef USE_ROCM
return rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0],
operands[1]);
#else
PTXBuilder ptxBuilder;
auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
if (32 == bitwidth) {
fdiv.o("full").o("f32");
} else if (64 == bitwidth) {
fdiv.o("rn").o("f64");
} else {
assert(0 && bitwidth && "not supported");
}
auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
fdiv(res, lhs, rhs);
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
return ret;
#endif
}
};
struct FMulOpConversion
: ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::MulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto lhsElemTy = getElementType(op.getLhs());
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
#ifdef USE_ROCM
return EmitDualBF16ElementwiseOp<LLVM::FMulOp>(loc, rewriter, operands);
#else
PTXBuilder builder;
auto ptxAsm = " { .reg .b16 c; \n"
" mov.b16 c, 0x8000U; \n" // 0.0
" fma.rn.bf16 $0, $1, $2, c; } \n";
auto &fMul = *builder.create<PTXInstr>(ptxAsm);
auto res = builder.newOperand("=h");
auto lhs = builder.newOperand(operands[0], "h");
auto rhs = builder.newOperand(operands[1], "h");
fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
return builder.launch(rewriter, loc, i16_ty, false);
#endif
} else {
return rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0],
operands[1]);
}
}
};
struct FAddOpConversion
: ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::AddFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto lhsElemTy = getElementType(op.getLhs());
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
#ifdef USE_ROCM
return EmitDualBF16ElementwiseOp<LLVM::FAddOp>(loc, rewriter, operands);
#else
PTXBuilder builder;
auto ptxAsm = "{ .reg .b16 c; \n"
" mov.b16 c, 0x3f80U; \n" // 1.0
" fma.rn.bf16 $0, $1, c, $2; } \n";
auto &fAdd = *builder.create<PTXInstr>(ptxAsm);
auto res = builder.newOperand("=h");
auto lhs = builder.newOperand(operands[0], "h");
auto rhs = builder.newOperand(operands[1], "h");
fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
return builder.launch(rewriter, loc, i16_ty, false);
#endif
} else {
return rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0],
operands[1]);
}
}
};
struct FSubOpConversion
: ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::SubFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto lhsElemTy = getElementType(op.getLhs());
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
#ifdef USE_ROCM
return EmitDualBF16ElementwiseOp<LLVM::FSubOp>(loc, rewriter, operands);
#else
PTXBuilder builder;
auto ptxAsm = " { .reg .b16 c; \n"
" mov.b16 c, 0xbf80U; \n" // -1.0
" fma.rn.bf16 $0, $2, c, $1;} \n";
auto &fSub = *builder.create<PTXInstr>(ptxAsm);
auto res = builder.newOperand("=h");
auto lhs = builder.newOperand(operands[0], "h");
auto rhs = builder.newOperand(operands[1], "h");
fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
return builder.launch(rewriter, loc, i16_ty, false);
#endif
} else {
return rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0],
operands[1]);
}
}
};
struct SIToFPOpConversion
: ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::SIToFPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto outElemTy = getElementType(op.getOut());
if (outElemTy.isBF16()) {
auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0]);
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value);
} else {
return rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0]);
}
}
};
struct FPToSIOpConversion
: ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::FPToSIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto inElemTy = getElementType(op.getIn());
if (inElemTy.isBF16()) {
auto value =
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, value);
} else {
return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0]);
}
}
};
struct ExtFOpConversion
: ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::ExtFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto inElemTy = getElementType(op.getIn());
if (inElemTy.isBF16()) {
auto outElemTy = getElementType(op.getOut());
assert(outElemTy.isF32() && "unsupported conversion");
return FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
} else {
return rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0]);
}
}
};
struct TruncFOpConversion
: ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::TruncFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto outElemTy = getElementType(op.getOut());
if (outElemTy.isBF16()) {
auto inElemTy = getElementType(op.getIn());
assert(inElemTy.isF32() && "unsupported conversion");
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, operands[0]);
} else {
return rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0]);
}
}
};
struct ExpOpConversionApprox
: ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
using Base =
ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
// For FP64 input, call __nv_expf for higher-precision calculation
if (elemTy.getIntOrFloatBitWidth() == 64)
return {};
const double log2e = 1.4426950408889634;
Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
PTXBuilder ptxBuilder;
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
auto output = ptxBuilder.newOperand("=f");
auto input = ptxBuilder.newOperand(prod, "f");
exp2(output, input);
return ptxBuilder.launch(rewriter, loc, f32_ty, false);
}
};
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation,
Value smem, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp)
POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp)
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_UNARY_OP(math::LogOp, math::LogOp)
POPULATE_UNARY_OP(math::CosOp, math::CosOp)
POPULATE_UNARY_OP(math::SinOp, math::SinOp)
POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);
patterns.add<FDivOpConversion>(typeConverter, benefit);
patterns.add<FSubOpConversion>(typeConverter, benefit);
patterns.add<FAddOpConversion>(typeConverter, benefit);
patterns.add<FMulOpConversion>(typeConverter, benefit);
patterns.add<ExtFOpConversion>(typeConverter, benefit);
patterns.add<TruncFOpConversion>(typeConverter, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
// For FP64 input type, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
#ifndef USE_ROCM
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
#endif
}

View File

@@ -0,0 +1,16 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation,
Value smem, PatternBenefit benefit);
#endif

View File

@@ -1,4 +1,4 @@
#include "triton/Conversion/TritonGPUToLLVM/GcnAsmFormat.h"
#include "triton/Conversion/TritonGPUToLLVM/GCNAsmFormat.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h"
@@ -27,8 +27,7 @@ GCNBuilder::Operand *GCNBuilder::newOperand(StringRef constraint) {
return opr;
}
GCNBuilder::Modifier *GCNBuilder::newModifier(StringRef modifier,
StringRef arg) {
GCNBuilder::Modifier *GCNBuilder::newModifier(StringRef modifier, StringRef arg) {
assert(!modifier.empty());
auto *mod = newModifier();
mod->modifier = modifier;
@@ -52,8 +51,7 @@ std::string GCNBuilder::getConstraints() const {
auto args = getAllArgs();
llvm::SmallVector<std::string, 4> argReprs;
for (auto arg : args)
if (!arg->constraint.empty())
argReprs.push_back(arg->constraint);
argReprs.push_back(arg->constraint);
return strJoin(argReprs, ",");
}
@@ -127,17 +125,6 @@ GCNInstr::Operand *GCNBuilder::newAddrOperand(mlir::Value addr,
return opr;
}
GCNInstr::Operand *GCNBuilder::newEmptyOperand(std::string arg) {
auto *opr = newOperand();
opr->repr = [arg](int idx) -> std::string {
std::stringstream ss;
ss << arg;
return ss.str();
};
return opr;
}
std::string GCNBuilder::dump() const {
llvm::SmallVector<std::string> lines;
for (auto &exec : executions) {
@@ -147,15 +134,13 @@ std::string GCNBuilder::dump() const {
return strJoin(lines, "\n\t");
}
GCNInstrExecution &GCNInstrCommon::call(ArrayRef<Operand *> oprs,
ArrayRef<Modifier *> mods) {
GCNInstrExecution &GCNInstrCommon::call(ArrayRef<Operand *> oprs, ArrayRef<Modifier *> mods) {
builder->executions.emplace_back(
std::make_unique<GCNInstrExecution>(this, oprs, mods));
return *builder->executions.back();
}
GCNInstrExecution &GCNInstrCommon::operator()(ArrayRef<Operand *> oprs,
ArrayRef<Modifier *> mods) {
GCNInstrExecution &GCNInstrCommon::operator()(ArrayRef<Operand *> oprs, ArrayRef<Modifier *> mods) {
return call(oprs, mods);
}
@@ -173,12 +158,16 @@ std::string GCNInstrExecution::dump() const {
std::string argsRepr = strJoin(argReprs, ", ");
llvm::SmallVector<std::string, 4> modReprs;
for (auto *mod : mods) {
for(auto *mod : mods) {
modReprs.push_back(mod->dump());
}
std::string modsRepr = strJoin(modReprs, " ");
os << instrRepr << " " << argsRepr << " " << modsRepr << ";";
if (!modsRepr.empty()) {
os << instrRepr << " " << argsRepr << ", " << modsRepr;
} else {
os << instrRepr << " " << argsRepr;
}
os.flush();
return osStr;
}
@@ -196,4 +185,4 @@ GCNInstrExecution::getArgList() const {
}
} // namespace triton
} // namespace mlir
} // namespace mlir

View File

@@ -0,0 +1,911 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "ConvertLayoutOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
: axisAnalysisPass(axisAnalysisPass) {}
// Get corresponding LLVM element values of \param value.
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
ConversionPatternRewriter &rewriter,
Location loc) {
if (!value)
return {};
if (!llValue.getType().isa<LLVM::LLVMStructType>())
return {llValue};
// Here, we assume that all inputs should have a blockedLayout
auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
return valueVals;
}
unsigned getVectorSize(Value ptr) const {
return axisAnalysisPass.getPtrVectorSize(ptr);
}
unsigned getMaskAlignment(Value mask) const {
return axisAnalysisPass.getMaskAlignment(mask);
}
protected:
AxisInfoAnalysis &axisAnalysisPass;
};
struct LoadOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LoadOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
// original values
Value ptr = op.ptr();
Value mask = op.mask();
Value other = op.other();
// adaptor values
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
// Determine the vectorization size
Type valueTy = op.getResult().getType();
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));
// Get the LLVM values for pointers
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
assert(ptrElems.size() == numElems);
// Get the LLVM values for mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(maskElems.size() == numElems);
}
// Get the LLVM values for `other`
// TODO: (goostavz) handle when other is const but not splat, which
// should be rarely seen
bool otherIsSplatConstInt = false;
DenseElementsAttr constAttr;
int64_t splatVal = 0;
if (other && valueElemTy.isa<IntegerType>() &&
matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
auto otherElems = getLLVMElems(other, llOther, rewriter, loc);
// vectorized iteration through all the pointer/mask/other elements
const int valueElemNbits =
std::max(8u, valueElemTy.getIntOrFloatBitWidth());
const int numVecs = numElems / vec;
SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
const size_t totalWidth = valueElemNbits * vec;
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
#ifdef USE_ROCM
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
for (size_t wordElem = 0; wordElem < wordNElems; ++wordElem) {
size_t elemOffset = vecStart + wordIdx * wordNElems + wordElem;
// get values
Value trueVal = load(ptrElems[elemOffset]);
Value zeroVal = bitcast(i32_val(0), valueElemTy);
Value falseVal = other ? load(otherElems[elemOffset]) : zeroVal;
// select value based on mask
Value ret = select(pred, trueVal, falseVal);
loadedVals.push_back(ret);
}
}
#else
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
const std::string readConstraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
const std::string writeConstraint =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
dstsOpr->listAppend(opr);
}
auto *addrOpr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
// Define the instruction opcode
auto &ld = ptxBuilder.create<>("ld")
->o("volatile", op.isVolatile())
.global()
.o("ca", op.cache() == triton::CacheModifier::CA)
.o("cg", op.cache() == triton::CacheModifier::CG)
.o("L1::evict_first",
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
op.evict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", hasL2EvictPolicy)
.v(nWords)
.b(width);
PTXBuilder::Operand *evictOpr{};
// Here lack a mlir::Value to bind to this operation, so disabled.
// if (has_l2_evict_policy)
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
if (!evictOpr)
ld(dstsOpr, addrOpr).predicate(pred, "b");
else
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
// PTX doesn't support mov.u8, so we need to use mov.u16
auto movWidth = width < 16 ? 16 : width;
PTXInstr &mov =
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
size_t size = width / valueElemNbits;
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
Value v = undef(vecTy);
for (size_t s = 0; s < size; ++s) {
Value falseVal = otherElems[vecStart + ii * size + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
v = bitcast(v, IntegerType::get(getContext(), width));
PTXInstr::Operand *opr{};
if (otherIsSplatConstInt)
opr = ptxBuilder.newConstantOperand(splatVal);
else
opr = ptxBuilder.newOperand(v, readConstraint);
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
}
}
// Create inline ASM signature
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
Type retTy = retTys.size() > 1
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
: retTys[0];
// TODO: if (has_l2_evict_policy)
// auto asmDialectAttr =
// LLVM::AsmDialectAttr::get(rewriter.getContext(),
// LLVM::AsmDialect::AD_ATT);
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
// Extract and store return values
SmallVector<Value> rets;
for (unsigned int ii = 0; ii < nWords; ++ii) {
Value curr;
if (retTy.isa<LLVM::LLVMStructType>()) {
curr = extract_val(IntegerType::get(getContext(), width), ret,
rewriter.getI64ArrayAttr(ii));
} else {
curr = ret;
}
curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
width / valueElemNbits));
rets.push_back(curr);
}
int tmp = width / valueElemNbits;
for (size_t ii = 0; ii < vec; ++ii) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
loadedVals.push_back(loaded);
}
#endif
} // end vec
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};
struct StoreOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.ptr();
Value mask = op.mask();
Value value = op.value();
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llValue = adaptor.value();
auto loc = op->getLoc();
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType();
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
assert(ptrElems.size() == valueElems.size());
// Determine the vectorization size
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(valueElems.size() == maskElems.size());
unsigned maskAlign = getMaskAlignment(mask);
vec = std::min(vec, maskAlign);
}
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const int numVecs = numElems / vec;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
const size_t totalWidth = valueElemNbits * vec;
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = vec_ty(valueElemTy, wordNElems);
SmallVector<std::pair<Value, std::string>> asmArgs;
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
// llWord is a width-len composition
Value llWord = undef(wordTy);
// Insert each value element to the composition
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = bitcast(elem, valueElemTy);
#ifdef USE_ROCM
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
Value ret = select(maskVal, elem , bitcast(i32_val(0), valueElemTy));
store(ret, ptrElems[elemOffset]);
}
}
#else
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgs.emplace_back(llWord, constraint);
}
// Prepare the PTX inline asm.
PTXBuilder ptxBuilder;
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
auto *asmAddr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
auto &ptxStoreInstr =
ptxBuilder.create<>("st")->global().v(nWords).b(width);
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
argTys.insert(argTys.end(), nWords, valArgTy);
auto asmReturnTy = void_ty(ctx);
ptxBuilder.launch(rewriter, loc, asmReturnTy);
#endif
}
rewriter.eraseOp(op);
return success();
}
};
struct AtomicCASOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicCASOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
Value ptr = op.ptr();
Value llPtr = adaptor.ptr();
Value llCmp = adaptor.cmp();
Value llVal = adaptor.val();
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
PTXBuilder ptxBuilderMemfence;
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfence();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
Value casPtr = ptrElements[0];
Value casCmp = cmpElements[0];
Value casVal = valElements[0];
PTXBuilder ptxBuilderAtomicCAS;
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
atom.global().o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();
PTXBuilder ptxBuilderStore;
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(pred);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
barrier();
Value ret = load(atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
return success();
}
};
struct AtomicRMWOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicRMWOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
auto atomicRmwAttr = op.atomic_rmw_op();
Value ptr = op.ptr();
Value val = op.val();
Value llPtr = adaptor.ptr();
Value llVal = adaptor.val();
Value llMask = adaptor.mask();
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1 for scalar
auto vec = getVectorSize(ptr);
Value mask = int_val(1, 1);
auto tid = tid_val();
// tensor
if (valueTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
auto shape = valueTy.getShape();
auto numElements = product(shape);
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
}
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwVal = undef(vecTy);
for (int ii = 0; ii < vec; ++ii) {
Value iiVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), ii);
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
}
Value rmwPtr = ptrElements[i];
Value rmwMask = maskElements[i];
rmwMask = and_(rmwMask, mask);
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
std::string tyId = valueElemNbits * vec == 64
? "l"
: (valueElemNbits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
auto sBits = std::to_string(valueElemNbits);
switch (atomicRmwAttr) {
case RMWOp::AND:
sTy = "b" + sBits;
break;
case RMWOp::OR:
sTy = "b" + sBits;
break;
case RMWOp::XOR:
sTy = "b" + sBits;
break;
case RMWOp::ADD:
sTy = "s" + sBits;
break;
case RMWOp::FADD:
rmwOp = "add";
rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
sTy = "f" + sBits;
sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
break;
case RMWOp::MAX:
sTy = "s" + sBits;
break;
case RMWOp::MIN:
sTy = "s" + sBits;
break;
case RMWOp::UMAX:
rmwOp = "max";
sTy = "u" + sBits;
break;
case RMWOp::UMIN:
rmwOp = "min";
sTy = "u" + sBits;
break;
case RMWOp::XCHG:
sTy = "b" + sBits;
break;
default:
return failure();
}
atom.o(rmwOp).o(sTy);
if (valueTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
}
} else {
PTXBuilder ptxBuilderMemfence;
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfenc();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
store(old, atomPtr);
barrier();
Value ret = load(atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
}
}
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
}
};
struct InsertSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
tensor::InsertSliceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// %dst = insert_slice %src into %dst[%offsets]
Location loc = op->getLoc();
Value dst = op.dest();
Value src = op.source();
Value res = op.result();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice for now");
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");
auto dstTy = dst.getType().dyn_cast<RankedTensorType>();
auto dstLayout = dstTy.getEncoding().dyn_cast<SharedEncodingAttr>();
auto llDst = adaptor.dest();
assert(dstLayout && "Unexpected dstLayout in InsertSliceOpConversion");
assert(op.hasUnitStride() &&
"Only unit stride supported by InsertSliceOpConversion");
// newBase = base + offset
// Triton support either static and dynamic offsets
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
SmallVector<Value, 4> offsets;
SmallVector<Value, 4> srcStrides;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i)) {
offsets.emplace_back(adaptor.offsets()[i]);
} else {
offsets.emplace_back(i32_val(op.getStaticOffset(i)));
}
// Like insert_slice_async, we only support slice from one dimension,
// which has a slice size of 1
if (op.getStaticSize(i) != 1) {
srcStrides.emplace_back(smemObj.strides[i]);
}
}
// Compute the offset based on the original strides of the shared memory
// object
auto offset = dot(rewriter, loc, offsets, smemObj.strides);
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = ptr_ty(elemTy, 3);
auto smemBase = gep(elemPtrTy, smemObj.base, offset);
auto llSrc = adaptor.source();
auto srcIndices =
emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
elemTy, loc, rewriter);
// Barrier is not necessary.
// The membar pass knows that it writes to shared memory and will handle it
// properly.
rewriter.replaceOp(op, llDst);
return success();
}
};
struct InsertSliceAsyncOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// insert_slice_async %src, %dst, %index, %mask, %other
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.dst();
Value res = op.result();
Value mask = op.mask();
Value other = op.other();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice_async for now");
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
"insert_slice_async: Unexpected rank of %src");
Value llDst = adaptor.dst();
Value llSrc = adaptor.src();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
Value llIndex = adaptor.index();
// %src
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
// %dst
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
SmallVector<Value, 4> offsetVals;
SmallVector<Value, 4> srcStrides;
for (auto i = 0; i < dstShape.size(); ++i) {
if (i == axis) {
offsetVals.emplace_back(llIndex);
} else {
offsetVals.emplace_back(i32_val(0));
srcStrides.emplace_back(smemObj.strides[i]);
}
}
// Compute the offset based on the original dimensions of the shared
// memory object
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
auto dstPtrTy = ptr_ty(resElemTy, 3);
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
// %mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(srcElems.size() == maskElems.size());
}
// %other
SmallVector<Value> otherElems;
if (llOther) {
// FIXME(Keren): always assume other is 0 for now
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
// assert(false && "insert_slice_async: Other value not supported yet");
otherElems = getLLVMElems(other, llOther, rewriter, loc);
assert(srcElems.size() == otherElems.size());
}
unsigned inVec = getVectorSize(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we will have elements
// that share the same tile indices. The index calculation will
// be cached.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// minVec = 2, inVec = 4, outVec = 2
// baseOffsetCol = 0 baseOffsetCol = 0
// tileVecIdxCol = 0 tileVecIdxCol = 1
// -/\- -/\-
// [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
// baseOffsetRow [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
auto vecIdx = elemIdx / minVec;
auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec);
auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec);
auto baseOffsetCol =
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
threadsPerCTA[inOrder[1]];
auto tileVecIdxCol = vecIdxCol % numVecCols;
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) {
// Swizzling
// Since the swizzling index is related to outVec, and we know minVec
// already, inVec doesn't matter
//
// (Numbers represent row indices)
// Example1:
// outVec = 2, inVec = 2, minVec = 2
// outVec = 2, inVec = 4, minVec = 2
// | [1 2] [3 4] [5 6] ... |
// | [3 4] [1 2] [7 8] ... |
// | [5 6] [7 8] [1 2] ... |
// Example2:
// outVec = 4, inVec = 2, minVec = 2
// | [1 2 3 4] [5 6 7 8] [9 10 11 12] ... |
// | [5 6 7 8] [1 2 3 4] [13 14 15 16] ... |
// | [9 10 11 12] [13 14 15 16] [1 2 3 4] ... |
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));
// srcShape and smemObj.shape maybe different if smemObj is a
// slice of the original shared memory object.
// So we need to use the original shape to compute the offset
Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
Value colOffset =
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
Value swizzleColOffset =
add(mul(xor_(swizzleIdx, phase), i32_val(outVec)),
urem(colOffset, i32_val(outVec)));
Value tileOffset = add(rowOffset, swizzleColOffset);
tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] =
gep(dstPtrTy, dstPtrBase, tileOffset);
}
// 16 * 8 = 128bits
auto maxBitWidth =
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
auto numWords = vecBitWidth / bitWidth;
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
// Tune CG and CA here.
auto byteWidth = bitWidth / 8;
CacheModifier srcCacheModifier =
byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
Value baseOffset =
add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
i32_val(baseOffsetCol));
Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto wordElemIdx = wordIdx * numWordElems;
auto &copyAsyncOp =
*ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
auto *dstOperand =
ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
auto *srcOperand =
ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
auto *srcSize = copySize;
if (op.mask()) {
// We don't use predicate in this case, setting src-size to 0
// if there's any mask. cp.async will automatically fill the
// remaining slots with 0 if cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now.
auto selectOp = select(maskElems[elemIdx + wordElemIdx],
i32_val(byteWidth), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
}
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
}
}
PTXBuilder ptxBuilder;
ptxBuilder.create<>("cp.async.commit_group")->operator()();
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
rewriter.replaceOp(op, llDst);
return success();
}
};
void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
axisInfoAnalysis, benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
axisInfoAnalysis, benefit);
patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
axisInfoAnalysis, benefit);
}

View File

@@ -0,0 +1,16 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_LOAD_STORE_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_LOAD_STORE_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit);
#endif

View File

@@ -1,9 +1,10 @@
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/raw_ostream.h"
#include <sstream> // unify to llvm::raw_string_ostream ?
// TODO(Superjomn): unify to llvm::raw_string_ostream
#include <sstream>
namespace mlir {
namespace triton {
@@ -33,7 +34,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
return argArchive.back().get();
}
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
std::stringstream ss;
ss << "0x" << std::hex << v;
return newConstantOperand(ss.str());
@@ -116,19 +117,36 @@ std::string PTXBuilder::dump() const {
return strJoin(lines, "\n\t");
}
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
if (onlyAttachMLIRArgs) {
// Nearly impossible to make the $0,$1 in two PTX code snippets to point to
// the same MLIR values in onlyAttachMLIRArgs mode.
assert(builder->executions.empty() &&
"builder can only hold a single execution when onlyAttachMIIRArgs "
"is true.");
builder->reorderArgArchive(oprs);
}
builder->executions.emplace_back(
std::make_unique<PTXInstrExecution>(this, oprs));
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
return *builder->executions.back();
}
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
return call(oprs);
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
return call(oprs, onlyAttachMLIRArgs);
}
std::string PTXInstrExecution::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
std::string instrRepr = strJoin(instr->instrParts, ".");
if (onlyAttachMLIRArgs)
return instrRepr;
if (pred) {
if (!pred->repr)
os << "@" << pred->dump() << " ";
@@ -136,8 +154,6 @@ std::string PTXInstrExecution::dump() const {
os << pred->repr(pred->idx) << " ";
}
std::string instrRepr = strJoin(instr->instrParts, ".");
llvm::SmallVector<std::string, 4> argReprs;
for (auto *arg : argsInOrder) {
argReprs.push_back(arg->dump());
@@ -162,5 +178,27 @@ PTXInstrExecution::getArgList() const {
return args;
}
PTXInstr &PTXInstr::global() {
o("global");
return *this;
}
PTXInstr &PTXInstr::shared() {
o("shared");
return *this;
}
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
return *this;
}
PTXInstr &PTXInstr::b(int width) {
o("b" + std::to_string(width));
return *this;
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,488 @@
#include "ReduceOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::shflSync;
using ::mlir::LLVM::storeShared;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
struct ReduceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (ReduceOpHelper(op).isFastReduction())
return matchAndRewriteFast(op, adaptor, rewriter);
return matchAndRewriteBasic(op, adaptor, rewriter);
}
private:
void accumulate(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value cur, bool isFirst) const {
if (isFirst) {
acc = cur;
return;
}
switch (redOp) {
case RedOp::ADD:
acc = add(acc, cur);
break;
case RedOp::FADD:
acc = fadd(acc.getType(), acc, cur);
break;
case RedOp::MIN:
acc = smin(acc, cur);
break;
case RedOp::MAX:
acc = smax(acc, cur);
break;
case RedOp::UMIN:
acc = umin(acc, cur);
break;
case RedOp::UMAX:
acc = umax(acc, cur);
break;
case RedOp::FMIN:
acc = fmin(acc, cur);
break;
case RedOp::FMAX:
acc = fmax(acc, cur);
break;
case RedOp::XOR:
acc = xor_(acc, cur);
break;
case RedOp::ARGMIN:
case RedOp::ARGMAX:
case RedOp::ARGUMIN:
case RedOp::ARGUMAX:
case RedOp::ARGFMIN:
case RedOp::ARGFMAX:
llvm::report_fatal_error(
"This accumulate implementation is not for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
}
}
void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value &accIndex, Value cur,
Value curIndex, bool isFirst) const {
if (isFirst) {
acc = cur;
accIndex = curIndex;
return;
}
switch (redOp) {
case RedOp::ARGMIN:
accIndex = select(
icmp_slt(acc, cur), accIndex,
select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smin(acc, cur);
break;
case RedOp::ARGMAX:
accIndex = select(
icmp_sgt(acc, cur), accIndex,
select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smax(acc, cur);
break;
case RedOp::ARGUMIN:
accIndex = select(
icmp_ult(acc, cur), accIndex,
select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umin(acc, cur);
break;
case RedOp::ARGUMAX:
accIndex = select(
icmp_ugt(acc, cur), accIndex,
select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umax(acc, cur);
break;
case RedOp::ARGFMIN:
accIndex = select(
fcmp_olt(acc, cur), accIndex,
select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmin(acc, cur);
break;
case RedOp::ARGFMAX:
accIndex = select(
fcmp_ogt(acc, cur), accIndex,
select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmax(acc, cur);
break;
case RedOp::ADD:
case RedOp::FADD:
case RedOp::MIN:
case RedOp::MAX:
case RedOp::UMIN:
case RedOp::UMAX:
case RedOp::FMIN:
case RedOp::FMAX:
case RedOp::XOR:
llvm::report_fatal_error(
"This accumulate implementation is only for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
}
}
// Use shared memory for reduction within warps and across warps
LogicalResult
matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
unsigned axis = op.axis();
bool withIndex = triton::ReduceOp::withIndex(op.redOp());
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcOrd = srcLayout.getOrder();
auto srcShape = srcTy.getShape();
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
ReduceOpHelper helper(op);
auto smemShape = helper.getScratchConfigBasic();
unsigned elems = product<unsigned>(smemShape);
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForBlockedLayout(srcLayout, srcShape);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.redOp(), accs[key],
accIndices[key], srcValues[i], curIndex, isFirst);
}
if (isFirst)
indices[key] = srcIndices[i];
}
// cached int32 constants
std::map<int, Value> ints;
ints[0] = i32_val(0);
for (int N = smemShape[axis] / 2; N > 0; N >>= 1)
ints[N] = i32_val(N);
Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]);
// reduce across threads
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
store(acc, writePtr);
if (withIndex)
store(accIndex, indexWritePtr);
SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
readIdx[axis] = ints[N];
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
Value readOffset = select(
readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
ints[0]);
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
barrier();
if (!withIndex) {
Value cur = load(readPtr);
accumulate(rewriter, loc, op.redOp(), acc, cur, false);
barrier();
store(acc, writePtr);
} else {
Value cur = load(readPtr);
Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
Value curIndex = load(indexReadPtr);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, cur,
curIndex, false);
barrier();
store(acc, writePtr);
store(accIndex, indexWritePtr);
}
}
}
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding();
auto resultShape = resultTy.getShape();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices =
emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}
return success();
}
// Use warp shuffle for reduction within warps and shared memory for data
// exchange across warps
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
unsigned axis = adaptor.axis();
bool withIndex = triton::ReduceOp::withIndex(op.redOp());
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto srcRank = srcTy.getRank();
auto order = getOrder(srcLayout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
ReduceOpHelper helper(op);
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned sizeIntraWarps = helper.getIntraWarpSize();
unsigned sizeInterWarps = helper.getInterWarpSize();
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcShape);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.redOp(), accs[key],
accIndices[key], srcValues[i], curIndex, isFirst);
}
if (isFirst)
indices[key] = srcIndices[i];
}
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
Value laneIdAxis = multiDimLaneId[axis];
Value warpIdAxis = multiDimWarpId[axis];
Value zero = i32_val(0);
Value laneZero = icmp_eq(laneIdAxis, zero);
Value warpZero = icmp_eq(warpIdAxis, zero);
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
// Reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false);
}
}
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
if (withIndex) {
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
}
}
barrier();
// The second round of shuffle reduction
// now the problem size: sizeInterWarps, s1, s2, .. , sn
// where sizeInterWarps is 2^m
//
// Each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
unsigned numThreads =
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
// FIXME(Qingyi): need predicate icmp_slt(threadId,
// i32_val(sizeInerWarps))
Value acc = load(readPtr);
Value accIndex;
if (withIndex) {
Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
accIndex = load(readIndexPtr);
}
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false);
}
}
// only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
Value laneIdModSizeInterWarpsIsZero =
icmp_eq(laneIdModSizeInterWarps, zero);
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
storeShared(rewriter, loc, writePtr, acc, pred);
if (withIndex) {
Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
}
if (round != elemsPerThread - 1) {
readOffset = add(readOffset, i32_val(numThreads));
}
}
// We could avoid this barrier in some of the layouts, however this is not
// the general case.
// TODO: optimize the barrier incase the layouts are accepted.
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices =
emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShapes[0], order);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}
return success();
}
};
void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
}

View File

@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCE_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCE_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit);
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,550 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
// TODO: refactor so that it doesn't fail if Allocation.h
// is included after utility.h (due to conflict in `store` macro
// and <atomic>
#include "triton/Analysis/Allocation.h"
//
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
// All the rights are reserved by the LLVM community.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
private:
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
bool filterArgAttrs,
SmallVectorImpl<NamedAttribute> &result) {
for (const auto &attr : attrs) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == "std.varargs" ||
(filterArgAttrs &&
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
continue;
result.push_back(attr);
}
}
/// Helper function for wrapping all attributes into a single DictionaryAttr
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
return DictionaryAttr::get(b.getContext(),
b.getNamedAttr("llvm.struct_attrs", attrs));
}
protected:
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
// to this legalization pattern.
LLVM::LLVMFuncOp
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
if (!llvmType)
return nullptr;
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
(funcOp.getNumResults() == 1)
? resAttrDicts
: rewriter.getArrayAttr(
{wrapAsStructAttrs(rewriter, resAttrDicts)});
attributes.push_back(rewriter.getNamedAttr(
FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
}
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
auto mapping = result.getInputMapping(i);
assert(mapping && "unexpected deletion of function argument");
for (size_t j = 0; j < mapping->size; ++j)
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
}
attributes.push_back(
rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
rewriter.getArrayAttr(newArgAttrs)));
}
for (const auto &pair : llvm::enumerate(attributes)) {
if (pair.value().getName() == "llvm.linkage") {
attributes.erase(attributes.begin() + pair.index());
break;
}
}
// Create an LLVM function, use external linkage by default until MLIR
// functions have linkage.
LLVM::Linkage linkage = LLVM::Linkage::External;
if (funcOp->hasAttr("llvm.linkage")) {
auto attr =
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
if (!attr) {
funcOp->emitError()
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
return nullptr;
}
linkage = attr.getLinkage();
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
&result)))
return nullptr;
return newFuncOp;
}
};
struct ConvertTritonGPUOpToLLVMPatternBase {
static Value
getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
ConversionPatternRewriter &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
return getStructFromElements(loc, elems, rewriter, structTy);
}
};
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp>,
public ConvertTritonGPUOpToLLVMPatternBase {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
allocation(allocation), smem(smem) {}
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
Value threadId = cast.getResult(0);
return threadId;
}
// -----------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) const {
unsigned rank = shape.size();
assert(rank == order.size());
auto reordered = reorder(shape, order);
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
SmallVector<Value> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape) const {
unsigned rank = shape.size();
assert(rank > 0);
SmallVector<Value> multiDim(rank);
if (rank == 1) {
multiDim[0] = linear;
} else {
Value remained = linear;
for (auto &&en : llvm::enumerate(shape.drop_back())) {
Value dimSize = idx_val(en.value());
multiDim[en.index()] = urem(remained, dimSize);
remained = udiv(remained, dimSize);
}
multiDim[rank - 1] = remained;
}
return multiDim;
}
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) const {
return linearize(rewriter, loc, reorder<Value>(multiDim, order),
reorder<unsigned>(shape, order));
}
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
auto rank = multiDim.size();
Value linear = idx_val(0);
if (rank > 0) {
linear = multiDim.back();
for (auto [dim, dimShape] :
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
Value dimSize = idx_val(dimShape);
linear = add(mul(linear, dimSize), dim);
}
}
return linear;
}
Value dot(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
assert(offsets.size() == strides.size());
Value ret = idx_val(0);
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
ret = add(ret, mul(offset, stride));
}
return ret;
}
// -----------------------------------------------------------------------
// Blocked layout indices
// -----------------------------------------------------------------------
// Get an index-base for each dimension for a \param blocked_layout.
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc,
ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
auto sizePerThread = blocked_layout.getSizePerThread();
auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
auto order = blocked_layout.getOrder();
unsigned rank = shape.size();
// delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// Wrap around multiDimWarpId/multiDimThreadId incase
// shape[k] > shapePerCTA[k]
auto maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
// multiDimWarpId[k] * threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
Value sizePerThreadK = idx_val(sizePerThread[k]);
multiDimBase[k] =
mul(sizePerThreadK, add(multiDimThreadId[k],
mul(multiDimWarpId[k], threadsPerWarpK)));
}
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
unsigned rank = shape.size();
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
SmallVector<unsigned> tilesPerDim(rank);
for (unsigned k = 0; k < rank; ++k)
tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
SmallVector<SmallVector<unsigned>> offset(rank);
for (unsigned k = 0; k < rank; ++k) {
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
++threadOffset)
for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
++elemOffset)
offset[k].push_back(blockOffset * sizePerThread[k] *
threadsPerWarp[k] * warpsPerCTA[k] +
warpOffset * sizePerThread[k] *
threadsPerWarp[k] +
threadOffset * sizePerThread[k] + elemOffset);
}
unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
for (unsigned n = 0; n < elemsPerThread; ++n) {
unsigned linearNanoTileId = n / totalSizePerThread;
unsigned linearNanoTileElemId = n % totalSizePerThread;
SmallVector<unsigned> multiDimNanoTileId =
getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
linearNanoTileElemId, sizePerThread, order);
for (unsigned k = 0; k < rank; ++k) {
unsigned reorderedMultiDimId =
multiDimNanoTileId[k] *
(sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
multiDimNanoTileElemId[k];
reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]);
}
}
return reorderedOffset;
}
// -----------------------------------------------------------------------
// Mma layout indices
// -----------------------------------------------------------------------
SmallVector<Value>
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
ret.push_back({i + 2, j});
ret.push_back({i + 2, j + 1});
ret.push_back({i, j + 8});
ret.push_back({i, j + 9});
ret.push_back({i + 2, j + 8});
ret.push_back({i + 2, j + 9});
}
}
return ret;
}
SmallVector<Value>
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
idx_val(_warpsPerCTA[1])};
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(warpId, warpsPerCTA[0]);
Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
Value offWarp0 = mul(warpId0, idx_val(16));
Value offWarp1 = mul(warpId1, idx_val(8));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
ret.push_back({i + 8, j});
ret.push_back({i + 8, j + 1});
}
}
return ret;
}
// -----------------------------------------------------------------------
// Get offsets / indices for any layout
// -----------------------------------------------------------------------
SmallVector<Value> emitBaseIndexForLayout(Location loc,
ConversionPatternRewriter &rewriter,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
if (mmaLayout.isAmpere())
return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
}
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
return emitOffsetForBlockedLayout(blockedLayout, shape);
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
}
llvm_unreachable("unsupported emitOffsetForLayout");
}
// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
// TODO: [phil] redundant indices computation do not appear to hurt
// performance much, but they could still significantly slow down
// computations.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const Attribute &layout, ArrayRef<int64_t> shape) const {
// step 1, delinearize threadId to get the base index
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
// step 2, get offset of each element
auto offset = emitOffsetForLayout(layout, shape);
// step 3, add offset to base, and reorder the sequence of indices to
// guarantee that elems in the same sizePerThread are adjacent in order
unsigned rank = shape.size();
unsigned elemsPerThread = offset.size();
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
SmallVector<Value>(rank));
for (unsigned n = 0; n < elemsPerThread; ++n)
for (unsigned k = 0; k < rank; ++k)
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
return multiDimIdx;
}
struct SmallVectorKeyInfo {
static unsigned getHashValue(const SmallVector<unsigned> &key) {
return llvm::hash_combine_range(key.begin(), key.end());
}
static bool isEqual(const SmallVector<unsigned> &lhs,
const SmallVector<unsigned> &rhs) {
return lhs == rhs;
}
static SmallVector<unsigned> getEmptyKey() {
return SmallVector<unsigned>();
}
static SmallVector<unsigned> getTombstoneKey() {
return {std::numeric_limits<unsigned>::max()};
}
};
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
size_t rank = shape.size();
auto parentIndices =
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
unsigned numIndices = parentIndices.size();
SmallVector<SmallVector<Value>> resultIndices;
for (unsigned i = 0; i < numIndices; ++i) {
SmallVector<Value> indices = parentIndices[i];
indices.erase(indices.begin() + dim);
resultIndices.push_back(indices);
}
return resultIndices;
}
// -----------------------------------------------------------------------
// Emit indices
// -----------------------------------------------------------------------
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
return emitIndicesForDistributedLayout(loc, b, blocked, shape);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
return emitIndicesForDistributedLayout(loc, b, mma, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
return emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
assert(0 && "emitIndices for layouts other than blocked & slice not "
"implemented yet");
return {};
}
}
// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
template <typename T>
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
T value) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
auto bufferId = allocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
Value offVal = idx_val(offset);
Value base = gep(ptrTy, smem, offVal);
return base;
}
protected:
const Allocation *allocation;
Value smem;
};
#endif

View File

@@ -0,0 +1,421 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "ConvertLayoutOpToLLVM.h"
#include "DotOpToLLVM.h"
#include "ElementwiseOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
#include "ReduceOpToLLVM.h"
#include "TritonGPUToLLVM.h"
#include "TypeConverter.h"
#include "ViewOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
namespace mlir {
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addIllegalDialect<mlir::StandardOpsDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addIllegalOp<mlir::FuncOp>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
} // namespace mlir
namespace {
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
LogicalResult
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
return failure();
auto ctx = funcOp->getContext();
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
#ifndef USE_ROCM
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntid",
rewriter.getIntegerAttr(i32_ty, 32 * numWarps));
#endif
rewriter.eraseOp(funcOp);
return success();
}
private:
int numWarps{0};
};
class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
public:
explicit ConvertTritonGPUToLLVM(int computeCapability)
: computeCapability(computeCapability) {}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
TritonLLVMConversionTarget target(*context);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// Step 1: Decompose unoptimized layout conversions to use shared memory
// Step 2: Decompose insert_slice_async to use load + insert_slice for
// pre-Ampere architectures or unsupported vectorized load sizes
// Step 3: Allocate shared memories and insert barriers
// Step 4: Convert SCF to CFG
// Step 5: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 6: Get axis and shared memory info
// Step 7: Convert the rest of ops via partial conversion
//
// The reason for putting step 3 before step 4 is that the membar
// analysis currently only supports SCF but not CFG. The reason for a
// separation between 5/7 is that, step 6 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 7.
// Step 1
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
// Step 2
decomposeInsertSliceAsyncOp(mod);
// Step 3
Allocation allocation(mod);
MembarAnalysis membarPass(&allocation);
membarPass.run();
// Step 4
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
scf::WhileOp, scf::ExecuteRegionOp>();
scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
return signalPassFailure();
// Step 5
RewritePatternSet func_patterns(context);
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
return signalPassFailure();
// Step 6 - get axis and shared memory info
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
mod->setAttr("triton_gpu.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
allocation.getSharedMemorySize()));
// Step 7 - rewrite rest of ops
// We set a higher benefit here to ensure triton's patterns runs before
// arith patterns for some encoding not supported by the community
// patterns.
RewritePatternSet patterns(context);
// Normal conversions
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ConvertLayoutOp
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// DotOp
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ElementwiseOp
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// LoadStoreOp
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ReduceOp
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ViewOp
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
#ifdef USE_ROCM
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP);
#else
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
#endif
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
}
private:
Value smem;
int computeCapability{};
void initSharedMemory(size_t size,
TritonGPUToLLVMTypeConverter &typeConverter) {
ModuleOp mod = getOperation();
OpBuilder b(mod.getBodyRegion());
auto loc = mod.getLoc();
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
// Set array size 0 and external linkage indicates that we use dynamic
// shared allocation to allow a larger shared memory size for each kernel.
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
auto global = b.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
"global_smem", /*value=*/Attribute(), /*alignment=*/0,
mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
SmallVector<LLVM::LLVMFuncOp> funcs;
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
assert(funcs.size() == 1 &&
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem = b.create<LLVM::AddressOfOp>(loc, global);
auto ptrTy =
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
}
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const {
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
// unless certain conditions are met
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcMma =
srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (srcMma && dstDotOp && !isMmaToDotShortcut(srcMma, dstDotOp)) {
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
getOrder(srcMma), numWarps));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
}
void decomposeBlockedToDotOperand(ModuleOp mod) const {
// Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
// because the codegen doesn't handle `blocked -> dot_op` directly
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcBlocked =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (srcBlocked && dstDotOp) {
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(
mod.getContext(), dstDotOp, srcType.getShape(),
getOrder(srcBlocked), srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
}
void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
// TODO(Keren): This is a hacky knob that may cause performance regression
// when decomposition has been performed. We should remove this knob once we
// have thorough analysis on async wait. Currently, we decompose
// `insert_slice_async` into `load` and `insert_slice` without knowing which
// `async_wait` is responsible for the `insert_slice_async`. To guarantee
// correctness, we blindly set the `async_wait` to wait for all async ops.
//
// There are two options to improve this:
// 1. We can perform a dataflow analysis to find the `async_wait` that is
// responsible for the `insert_slice_async` in the backend.
// 2. We can modify the pipeline to perform the decomposition before the
// `async_wait` is inserted. However, it is also risky because we don't know
// the correct vectorized shape yet in the pipeline pass. Making the
// pipeline pass aware of the vectorization could introduce additional
// dependencies on the AxisInfoAnalysis and the Coalesce analysis.
bool decomposed = false;
// insert_slice_async %src, %dst, %idx, %mask, %other
// =>
// %tmp = load %src, %mask, %other
// %res = insert_slice %tmp into %dst[%idx]
mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
OpBuilder builder(insertSliceAsyncOp);
// Get the vectorized load size
auto src = insertSliceAsyncOp.src();
auto dst = insertSliceAsyncOp.dst();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcBlocked =
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto resSharedLayout =
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
auto resElemTy = dstTy.getElementType();
unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
auto maxBitWidth =
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
auto byteWidth = bitWidth / 8;
// If the load byte width is not eligible or the current compute
// capability does not support async copy, then we do decompose
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
computeCapability)
.contains(byteWidth))
return;
// load
auto tmpTy =
RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
auto loadOp = builder.create<triton::LoadOp>(
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
insertSliceAsyncOp.isVolatile());
// insert_slice
auto axis = insertSliceAsyncOp.axis();
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
offsets[axis] = insertSliceAsyncOp.index();
for (size_t i = 0; i < dstTy.getRank(); i++) {
if (i != axis)
sizes[i] = intAttr(dstTy.getShape()[i]);
}
auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.dst(),
offsets, sizes, strides);
// Replace
insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
insertSliceAsyncOp.erase();
decomposed = true;
});
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
// async wait is supported in Ampere and later
asyncWaitOp.erase();
} else if (decomposed) {
// Wait for all previous async ops
OpBuilder builder(asyncWaitOp);
auto newAsyncWaitOp =
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
asyncWaitOp.erase();
}
});
}
};
} // anonymous namespace
namespace mlir {
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability) {
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability);
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,150 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"
#include "DotOpHelpers.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis = nullptr)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// Internally store float8 as int8
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}
Type convertTritonPointerType(triton::PointerType type) {
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
unsigned numElementsPerThread = getElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto dotOpLayout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
if (dotOpLayout.getParent()
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
int numElemsPerThread =
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
} else { // for parent is MMA layout
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = convertType(type.getElementType());
if (mmaLayout.isAmpere()) {
const llvm::DenseMap<int, Type> targetTyMap = {
{32, elemTy},
{16, vec_ty(elemTy, 2)},
{8, vec_ty(elemTy, 4)},
};
Type targetTy;
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
} else {
assert(false && "Unsupported element type");
}
if (dotOpLayout.getOpIdx() == 0) { // $a
auto elems =
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, targetTy));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
auto elems =
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
return struct_ty(SmallVector<Type>(elems, targetTy));
}
}
if (mmaLayout.isVolta()) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
// TODO[Superjomn]: Both transA and transB are not available here.
bool trans = false;
// TODO[Superjomn]: The order of A and B are not available here.
SmallVector<unsigned> order({1, 0});
if (trans) {
std::swap(shape[0], shape[1]);
std::swap(order[0], order[1]);
}
if (dotOpLayout.getOpIdx() == 0) { // $a
int elems = helper.numElemsPerThreadA(shape, order);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
int elems = helper.numElemsPerThreadB(shape, order);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}
}
llvm::errs() << "Unexpected dot operand layout detected in "
"TritonToLLVMTypeConverter";
return llvm::None;
}
return llvm::None;
}
};
#endif

View File

@@ -0,0 +1,392 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "triton/Conversion/TritonGPUToLLVM/GCNAsmFormat.h"
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
#define lshr(...) rewriter.create<LLVM::LShrOp>(loc, __VA_ARGS__)
#define ashr(...) rewriter.create<LLVM::AShrOp>(loc, __VA_ARGS__)
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bitcast(val__, type__) \
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
#define extract_val(...) rewriter.create<LLVM::ExtractValueOp>(loc, __VA_ARGS__)
#define insert_element(...) \
rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
#define extract_element(...) \
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define fcmp_ogt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::ogt, lhs, rhs)
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
#define icmp_slt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define icmp_sle(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
#define icmp_sgt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
#define icmp_sge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
#define icmp_ult(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
#define icmp_ule(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
#define icmp_ugt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
#define icmp_uge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
// Types
#define i32_ty rewriter.getIntegerType(32)
#define i16_ty rewriter.getIntegerType(16)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
// Constants
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define idx_val(...) \
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
__VA_ARGS__)
#define tid_val() getThreadId(rewriter, loc)
namespace mlir {
namespace triton {
// Delinearize supposing order is [0, 1, .. , n]
template <typename T>
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
llvm::ArrayRef<T> shape) {
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size();
T accMul = product(shape.drop_back());
T linearRemain = linearIndex;
llvm::SmallVector<T> multiDimIndex(rank);
for (int i = rank - 1; i >= 0; --i) {
multiDimIndex[i] = linearRemain / accMul;
linearRemain = linearRemain % accMul;
if (i != 0) {
accMul = accMul / shape[i - 1];
}
}
return multiDimIndex;
}
template <typename T>
llvm::SmallVector<T> getMultiDimIndex(T linearIndex, llvm::ArrayRef<T> shape,
llvm::ArrayRef<unsigned> order) {
size_t rank = shape.size();
assert(rank == order.size());
auto reordered = reorder(shape, order);
auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
llvm::SmallVector<T> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}
// Linearize supposing order is [0, 1, .. , n]
template <typename T>
static T getLinearIndexImpl(llvm::ArrayRef<T> multiDimIndex,
llvm::ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size());
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size();
T accMul = product(shape.drop_back());
T linearIndex = 0;
for (int i = rank - 1; i >= 0; --i) {
linearIndex += multiDimIndex[i] * accMul;
if (i != 0) {
accMul = accMul / shape[i - 1];
}
}
return linearIndex;
}
template <typename T>
static T getLinearIndex(llvm::ArrayRef<T> multiDimIndex,
llvm::ArrayRef<T> shape,
llvm::ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
reorder(shape, order));
}
} // namespace triton
namespace LLVM {
using namespace mlir::triton;
static Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (const auto &v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
return llvmStruct;
}
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
}
return results;
}
// Create a 32-bit integer constant.
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v));
}
static Value createConstantF32(Location loc, PatternRewriter &rewriter,
float v) {
auto type = type::f32Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF32FloatAttr(v));
}
static Value createConstantF64(Location loc, PatternRewriter &rewriter,
float v) {
auto type = type::f64Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF64FloatAttr(v));
}
// Create an index type constant.
static Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
// Create an integer constant of \param width bits.
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
short width, int64_t value) {
Type ty = builder.getIntegerType(width);
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
/// Helper function to get strides from a given shape and its order
static SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
Location loc, ConversionPatternRewriter &rewriter) {
auto rank = shape.size();
SmallVector<Value> strides(rank);
int64_t stride = 1;
for (auto idx : order) {
strides[idx] = i32_val(stride);
stride *= shape[idx];
}
return strides;
}
struct SharedMemoryObject {
Value base; // i32 ptr. The start address of the shared memory object.
// We need to store strides as Values but not integers because the
// extract_slice instruction can take a slice at arbitrary offsets.
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
// 32, we need to let the instruction that uses $a to be aware of that.
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
// we store strides into an attribute array of integers, the information
// cannot pass through block argument assignment because attributes are
// associated with operations but not Values.
// TODO(Keren): We may need to figure out a way to store strides as integers
// if we want to support more optimizations.
SmallVector<Value>
strides; // i32 int. The strides of the shared memory object.
SmallVector<Value> offsets; // i32 int. The offsets of the shared memory
// objects from the originally allocated object.
SharedMemoryObject(Value base, ArrayRef<Value> strides,
ArrayRef<Value> offsets)
: base(base), strides(strides.begin(), strides.end()),
offsets(offsets.begin(), offsets.end()) {}
SharedMemoryObject(Value base, ArrayRef<int64_t> shape,
ArrayRef<unsigned> order, Location loc,
ConversionPatternRewriter &rewriter)
: base(base) {
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
for (auto idx : order) {
offsets.emplace_back(i32_val(0));
}
}
SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
elems.append(strides.begin(), strides.end());
elems.append(offsets.begin(), offsets.end());
return elems;
}
SmallVector<Type> getTypes() const {
SmallVector<Type> types;
types.push_back(base.getType());
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
return types;
}
Value getCSwizzleOffset(int order) const {
assert(order >= 0 && order < strides.size());
return offsets[order];
}
Value getBaseBeforeSwizzle(int order, Location loc,
ConversionPatternRewriter &rewriter) const {
Value cSwizzleOffset = getCSwizzleOffset(order);
Value offset = sub(i32_val(0), cSwizzleOffset);
Type type = base.getType();
return gep(type, base, offset);
}
};
static SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
auto rank = (elems.size() - 1) / 2;
return {/*base=*/elems[0],
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
}
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, Value val, Value pred) {
#if USE_ROCM
store(val, ptr);
return val;
#else
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
PTXBuilder builder;
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
auto *valOpr = builder.newOperand(val, c);
auto &st = builder.create<>("st")->shared().b(bits);
st(ptrOpr, valOpr).predicate(pred, "b");
return builder.launch(rewriter, loc, void_ty(ctx));
#endif
}
static Value shflSync(Location loc, ConversionPatternRewriter &rewriter,
Value val, int i) {
unsigned bits = val.getType().getIntOrFloatBitWidth();
if (bits == 64) {
Type vecTy = vec_ty(f32_ty, 2);
Value vec = bitcast(val, vecTy);
Value val0 = extract_element(f32_ty, vec, i32_val(0));
Value val1 = extract_element(f32_ty, vec, i32_val(1));
val0 = shflSync(loc, rewriter, val0, i);
val1 = shflSync(loc, rewriter, val1, i);
vec = undef(vecTy);
vec = insert_element(vecTy, vec, val0, i32_val(0));
vec = insert_element(vecTy, vec, val1, i32_val(1));
return bitcast(vec, val.getType());
}
#ifdef USE_ROCM
// This map facilates the butterfly shuffle pattern for a stride less than 16. The pattern stride is the key of the map.
DenseMap<short, unsigned int> masks{{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
GCNBuilder builder;
auto shfl = builder.create("ds_swizzle_b32");
auto dOpr = builder.newOperand("=v");
auto aOpr = builder.newOperand(val, "v");
auto maskOpr = builder.newConstantOperand("offset:" + std::to_string(masks[i]));
(*shfl)(dOpr, aOpr, maskOpr);
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
(*swait)();
#else
PTXBuilder builder;
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
auto *dOpr = builder.newOperand("=r");
auto *aOpr = builder.newOperand(val, "r");
auto *bOpr = builder.newConstantOperand(i);
auto *cOpr = builder.newConstantOperand("0x1f");
auto *maskOpr = builder.newConstantOperand("0xffffffff");
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
#endif
return builder.launch(rewriter, loc, val.getType(), false);
}
} // namespace LLVM
} // namespace mlir
#endif

View File

@@ -0,0 +1,229 @@
#include "ViewOpToLLVM.h"
#include "DotOpHelpers.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
struct SplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
// LLVM::StructType value.
//
// @elemType: the element type in operand.
// @resType: the return type of the Splat-like op.
// @constVal: a LLVM::ConstantOp or other scalar value.
static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
} else if (auto mmaLayout =
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
return convertSplatLikeOpWithMmaLayout(
mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
} else
assert(false && "Unsupported layout found in ConvertSplatLikeOp");
return {};
}
static Value convertSplatLikeOpWithMmaLayout(
const MmaEncodingAttr &layout, Type resType, Type elemType,
Value constVal, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
if (layout.isAmpere()) {
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
rewriter, structTy);
}
if (layout.isVolta()) {
DotOpMmaV1ConversionHelper helper(layout);
int repM = helper.getRepM(shape[0]);
int repN = helper.getRepN(shape[1]);
// According to mma layout of v1, each thread process 8 elements.
int elems = 8 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(elems, elemType));
return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
rewriter, structTy);
}
assert(false && "Unsupported mma layout found");
return {};
}
LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op->getLoc();
auto src = adaptor.src();
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, {llStruct});
return success();
}
};
// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
// the logic is the same as triton::SplatOp, so the underlying implementation
// is reused.
struct ArithConstantSplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
using ConvertTritonGPUOpToLLVMPattern<
arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto value = op.getValue();
if (!value.dyn_cast<SplatElementsAttr>())
return failure();
auto loc = op->getLoc();
LLVM::ConstantOp arithConstantOp;
auto values = op.getValue().dyn_cast<SplatElementsAttr>();
auto elemType = values.getElementType();
Attribute val;
if (elemType.isBF16() || type::isFloat(elemType)) {
val = values.getValues<FloatAttr>()[0];
} else if (type::isInt(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
<< value.getType() << "\n";
return failure();
}
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, llStruct);
return success();
}
};
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
using OpAdaptor = typename CatOp::Adaptor;
explicit CatOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
// unpack input values
auto lhsVals = getElementsFromStruct(loc, adaptor.lhs(), rewriter);
auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
// concatenate (and potentially reorder) values
SmallVector<Value> retVals;
for (Value v : lhsVals)
retVals.push_back(v);
for (Value v : rhsVals)
retVals.push_back(v);
// pack and replace
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
return success();
}
};
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly run `rewriter.replaceOp(op, adaptor.src())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
struct TransOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto srcSmemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
srcSmemObj.strides[0]};
SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
srcSmemObj.offsets[0]};
auto dstSmemObj =
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
}

View File

@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_VIEW_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_VIEW_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit);
#endif

View File

@@ -1,5 +1,5 @@
add_mlir_conversion_library(TritonToTritonGPU
TritonToTritonGPU.cpp
TritonToTritonGPUPass.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU

View File

@@ -1,16 +1,25 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
namespace {
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
@@ -114,6 +123,7 @@ void populateArithmeticPatternsAndLegality(
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
}
@@ -220,7 +230,21 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
RankedTensorType origType = op.getType().cast<RankedTensorType>();
auto origShape = origType.getShape();
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
int numWarps = typeConverter->getNumWarps();
SmallVector<unsigned> retSizePerThread = {1, 1};
if (origShape[0] * origShape[1] / (numWarps * 32) >= 4)
retSizePerThread = {2, 2};
if (origShape[0] * origShape[1] / (numWarps * 32) >= 16)
retSizePerThread = {4, 4};
SmallVector<unsigned> retOrder = {1, 0};
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
getContext(), origShape, retSizePerThread, retOrder, numWarps);
RankedTensorType retType =
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
// a & b must be of smem layout
auto aType = adaptor.a().getType().cast<RankedTensorType>();
auto bType = adaptor.b().getType().cast<RankedTensorType>();
@@ -230,24 +254,86 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
return failure();
Value a = adaptor.a();
Value b = adaptor.b();
SmallVector<unsigned, 2> order{1, 0};
if (!aEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
Value c = adaptor.c();
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, c,
adaptor.allowTF32());
return success();
}
};
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
using OpConversionPattern<triton::CatOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// For now, this behaves like generic, but this will evolve when
// we add support for `can_reorder=False`
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
adaptor.getOperands());
return success();
}
};
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = adaptor.src();
auto srcType = src.getType().cast<RankedTensorType>();
Attribute srcEncoding = srcType.getEncoding();
if (!srcEncoding)
return failure();
if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
// TODO: end-to-end correctness is broken if
// the input is blocked and the output is shared
// with different order. Maybe a backend issue in BlockedToShared?
SmallVector<unsigned> order = {1, 0};
if (auto srcBlockedEncoding =
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
srcEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
srcType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), srcEncoding);
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
src);
}
auto srcSharedEncoding =
srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
srcSharedEncoding.getOrder().end());
SmallVector<int64_t> retShapes(srcType.getShape().begin(),
srcType.getShape().end());
std::reverse(retOrder.begin(), retOrder.end());
std::reverse(retShapes.begin(), retShapes.end());
auto retEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
auto retType =
RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);
rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
return success();
}
};
@@ -278,6 +364,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
}
};
struct TritonAtomicCASPattern
: public OpConversionPattern<triton::AtomicCASOp> {
using OpConversionPattern<triton::AtomicCASOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.cmp(), adaptor.val());
return success();
}
};
struct TritonAtomicRMWPattern
: public OpConversionPattern<triton::AtomicRMWOp> {
using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;
@@ -371,13 +471,15 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::BitcastOp>,
TritonGenericPattern<triton::FpToFpOp>,
TritonGenericPattern<triton::IntToPtrOp>,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
TritonPrintfPattern>(typeConverter, context);
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
TritonAtomicRMWPattern>(typeConverter, context);
}
//
@@ -441,10 +543,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
}
};
// This is borrowed from ConvertFIfOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: Generalize this to any type conversion, not just 1:1.
//
// We need to implement something more sophisticated here that tracks which
// types convert to which other types and does the appropriate
// materialization logic.
// For example, it's possible that one result type converts to 0 types and
// another to 2 types, so newResultTypes would at least be the right size to
// not crash in the llvm::zip call below, but then we would set the the
// wrong type on the SSA values! These edge cases are also why we cannot
// safely use the TypeConverter::convertTypes helper here.
SmallVector<Type> newResultTypes;
for (auto type : op.getResultTypes()) {
Type newType = typeConverter->convertType(type);
if (!newType)
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
newResultTypes.push_back(newType);
}
// See comments in the ForOp pattern for why we clone without regions and
// then inline.
scf::IfOp newOp =
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
newOp.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
newOp.getElseRegion().end());
// Update the operands and types.
newOp->setOperands(adaptor.getOperands());
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
rewriter.replaceOp(op, newOp.getResults());
return success();
}
};
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
context);
}
class ConvertTritonToTritonGPU

View File

@@ -124,6 +124,29 @@ void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
namespace mlir {
namespace triton {
//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
::mlir::TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
auto srcEltType = inputs.front();
auto dstEltType = outputs.front();
auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
if (srcTensorType && dstTensorType) {
srcEltType = srcTensorType.getElementType();
dstEltType = dstTensorType.getElementType();
}
// Check whether fp8 <=> fp16, bf16, f32, f64
// Make `srcEltType` always the fp8 side
if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
std::swap(srcEltType, dstEltType);
if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
return false;
return dstEltType.isF16() || dstEltType.isBF16() || dstEltType.isF32() ||
dstEltType.isF64();
}
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {
@@ -191,6 +214,20 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
// type is the same as the accumulator
auto accTy = operands[2].getType().cast<RankedTensorType>();
inferredReturnTypes.push_back(accTy);
// verify encodings
auto aEnc = operands[0].getType().cast<RankedTensorType>().getEncoding();
auto bEnc = operands[1].getType().cast<RankedTensorType>().getEncoding();
auto retEnc = accTy.getEncoding();
if (aEnc) {
assert(bEnc);
Dialect &dialect = aEnc.getDialect();
auto interface = dyn_cast<DialectInferLayoutInterface>(&dialect);
if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
return mlir::failure();
if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
return mlir::failure();
}
return mlir::success();
}
@@ -203,12 +240,17 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
Value arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto argEltTy = argTy.getElementType();
auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
auto redOp =
attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
auto retEltTy = withIndex ? i32Ty : argEltTy;
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {
// 0d-tensor -> scalar
inferredReturnTypes.push_back(argEltTy);
inferredReturnTypes.push_back(retEltTy);
} else {
// nd-tensor where n >= 1
// infer encoding
@@ -227,11 +269,20 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
}
// create type
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, argEltTy, retEncoding));
RankedTensorType::get(retShape, retEltTy, retEncoding));
}
return mlir::success();
}
bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
return redOp == mlir::triton::RedOp::ARGMIN ||
redOp == mlir::triton::RedOp::ARGMAX ||
redOp == mlir::triton::RedOp::ARGUMIN ||
redOp == mlir::triton::RedOp::ARGUMAX ||
redOp == mlir::triton::RedOp::ARGFMIN ||
redOp == mlir::triton::RedOp::ARGFMAX;
}
//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
@@ -244,7 +295,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
//-- ExpandDimsOp --
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
MLIRContext *context, Optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
@@ -260,11 +311,9 @@ mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding)
.failed()) {
llvm::report_fatal_error("failed to infer layout for ExpandDimsOp");
return mlir::failure();
}
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
.failed())
return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp");
}
// create type
auto argEltTy = argTy.getElementType();

View File

@@ -19,7 +19,7 @@ mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
for (auto resultType : op->getResultTypes())
if (failed(verifySameEncoding(resultType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
<< "requires the same encoding for all operands and results";
return verifySameOperandsEncoding(op);
}
@@ -48,7 +48,8 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
<< " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule";
<< *op << " doesn't follow the rule (" << numElements << ")"
<< " elements";
}
}
for (auto opType : op->getResultTypes()) {
@@ -62,7 +63,8 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
<< " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule";
<< *op << " doesn't follow the rule (" << numElements << ")"
<< " elements";
}
}
return success();

View File

@@ -196,7 +196,7 @@ public:
patterns.add<CombineDotAddFRevPattern>(context);
// %}
patterns.add<CombineSelectMaskedLoadPattern>(context);
patterns.add<CombineAddPtrPattern>(context);
// patterns.add<CombineAddPtrPattern>(context);
patterns.add<CombineBroadcastConstantPattern>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())

View File

@@ -12,30 +12,31 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
// TODO: this fails for addptr(addptr(ptr, i32), i64)
// Commented out until fixed
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
// (ref: ArithmeticCanonicalization.td)
def CombineAddPtrPattern : Pat<
(TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// def CombineAddPtrPattern : Pat<
// (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
// (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;

View File

@@ -42,13 +42,11 @@ static Type getPointeeType(Type type) {
namespace gpu {
// TODO: Inheritation of layout attributes
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
auto layout = tensorType.getEncoding();
auto shape = tensorType.getShape();
// TODO: Inheritance of layout attributes
// so that all distributed layouts implement
// these utilities
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
@@ -57,28 +55,102 @@ unsigned getElemsPerThread(Type type) {
return mmaLayout.getElemsPerThread(shape);
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return sharedLayout.getElemsPerThread(shape);
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return dotLayout.getElemsPerThread(shape);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
}
SmallVector<unsigned> getSizePerThread(Attribute layout) {
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::Float8Type>() ||
type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
}
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
blockedLayout.getThreadsPerWarp().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
return {4, 8};
if (mmaLayout.isAmpere())
return {8, 4};
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
blockedLayout.getWarpsPerCTA().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
mmaLayout.getWarpsPerCTA().end());
}
assert(0 && "getWarpsPerCTA not implemented");
return {};
}
SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
return SmallVector<unsigned>{2, 2};
if (mmaLayout.isAmpere()) {
return {2, 2};
} else if (mmaLayout.isVolta()) {
// Note: here the definition of sizePerThread is obscure, which doesn't
// mean vecSize=4 can be supported in the last dimension.
return {2, 4};
} else {
llvm_unreachable("Unexpected mma version");
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
assert(parentMmaLayout.isAmpere() &&
"mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto opIdx = dotLayout.getOpIdx();
if (opIdx == 0) {
return {2, 4};
} else if (opIdx == 1) {
return {4, 1};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
return {};
}
} else {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
return {};
}
} else {
assert(0 && "getSizePerThread not implemented");
return {};
}
}
SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
return {1, 2};
} else {
return getSizePerThread(layout);
}
}
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
@@ -104,23 +176,48 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(blockedParent.getSizePerThread()[d] *
blockedParent.getThreadsPerWarp()[d] *
blockedParent.getWarpsPerCTA()[d]);
}
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(getShapePerCTA(parent)[d]);
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.isAmpere())
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.isVolta())
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
assert(0 && "Unexpected MMA layout version found");
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
assert(parentMmaLayout.isAmpere() &&
"mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto opIdx = dotLayout.getOpIdx();
if (opIdx == 0) {
return {parentShapePerCTA[0], 16};
} else if (opIdx == 1) {
return {16, parentShapePerCTA[1]};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
}
} else {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.isVolta()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
llvm_unreachable("Unexpected mma version");
}
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
@@ -132,7 +229,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
return {1, 0};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return {1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim();
@@ -240,11 +339,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
return product<unsigned>(elemsPerThread);
}
SmallVector<int64_t>
SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
template <class T>
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
size_t rank = shape.size();
unsigned dim = getDim();
SmallVector<int64_t> retShape(rank + 1);
SmallVector<T> retShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim)
retShape[d] = shape[d];
@@ -255,37 +354,35 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
}
return retShape;
}
template SmallVector<unsigned>
SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto parent = getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
assert(rank == blockedParent.getSizePerThread().size() - 1 &&
"unexpected rank in SliceEncodingAttr::getElemsPerThread");
return blockedParent.getElemsPerThread(paddedShape(shape));
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
return ::getElemsPerThread(parent, paddedShape(shape));
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
assert((getVersion() == 1 || getVersion() == 2) &&
"Only version 1 and 2 is supported");
assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported");
int res = 0;
if (getVersion() == 1) {
if (isVolta()) {
unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
// Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
// matrix as result.
res = mmasRow * mmasCol * (16 * 16 / 32);
} else if (getVersion() == 2) {
} else if (isAmpere()) {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
} else {
llvm_unreachable("Unexpected mma version");
}
return res;
@@ -297,6 +394,15 @@ unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
return 0;
}
unsigned
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
if (auto blockedLayout = getParent().dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
}
assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
return 0;
}
//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//
@@ -369,12 +475,17 @@ Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
unsigned version = 0;
unsigned versionMajor = 0;
unsigned versionMinor = 0;
SmallVector<unsigned, 2> warpsPerCTA;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "version") {
if (parseUInt(parser, attr, version, "version").failed())
if (attr.getName() == "versionMajor") {
if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
return {};
}
if (attr.getName() == "versionMinor") {
if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
return {};
}
if (attr.getName() == "warpsPerCTA") {
@@ -383,13 +494,14 @@ Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
}
}
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), version,
warpsPerCTA);
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), versionMajor,
versionMinor, warpsPerCTA);
}
void MmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "version = " << getVersion() << ", "
<< "versionMajor = " << getVersionMajor() << ", "
<< "versionMinor = " << getVersionMinor() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< "}>";
}
@@ -468,6 +580,58 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
<< "}>";
}
//===----------------------------------------------------------------------===//
// Mma encoding
//===----------------------------------------------------------------------===//
bool MmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; }
bool MmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; }
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
std::tuple<bool, bool, bool, bool>
MmaEncodingAttr::decodeVoltaLayoutStates() const {
unsigned versionMinor = getVersionMinor();
bool isARow = versionMinor & (1 << 0);
bool isBRow = versionMinor & (1 << 1);
bool isAVec4 = versionMinor & (1 << 2);
bool isBVec4 = versionMinor & (1 << 3);
return std::make_tuple(isARow, isBRow, isAVec4, isBVec4);
}
//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if (parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()) {
isMMAv1Row = attrs.get("isMMAv1Row");
if (!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
if (getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//
// InsertSliceAsyncOp
//===----------------------------------------------------------------------===//
@@ -527,30 +691,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
}
//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
@@ -591,21 +731,32 @@ struct TritonGPUInferLayoutInterface
LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const override {
Attribute &resultEncoding,
Optional<Location> location) const override {
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
if (!sliceEncoding) {
llvm::report_fatal_error(
"ExpandDimsOp operand encoding must be SliceEncodingAttr");
return failure();
}
if (sliceEncoding.getDim() != axis) {
llvm::report_fatal_error(
"Incompatible slice dimension for ExpandDimsOp operand");
return failure();
}
if (!sliceEncoding)
return emitOptionalError(
location, "ExpandDimsOp operand encoding must be SliceEncodingAttr");
if (sliceEncoding.getDim() != axis)
return emitOptionalError(
location, "Incompatible slice dimension for ExpandDimsOp operand");
resultEncoding = sliceEncoding.getParent();
return success();
}
LogicalResult inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
Attribute retEncoding,
Optional<Location> location) const override {
if (auto dotOpEnc = operandEncoding.dyn_cast<DotOperandEncodingAttr>()) {
if (opIdx != dotOpEnc.getOpIdx())
return emitOptionalError(location, "Wrong opIdx");
if (retEncoding != dotOpEnc.getParent())
return emitOptionalError(location, "Incompatible parent encoding");
} else
return emitOptionalError(
location, "Dot's a/b's encoding should be of DotOperandEncodingAttr");
return success();
}
};
void TritonGPUDialect::initialize() {

View File

@@ -7,7 +7,7 @@ add_mlir_dialect_library(TritonGPUTransforms
CanonicalizeLoops.cpp
Combine.cpp
Pipeline.cpp
Swizzle.cpp
Prefetch.cpp
TritonGPUConversion.cpp
DEPENDS

View File

@@ -24,7 +24,7 @@ struct CanonicalizePass
// The following piece of code is a workaround to
// very crudely remove dead code, by making an iteration
// argument yield itself if it is not used to create
// side-effects anywhere.
// side effects anywhere.
getOperation()->walk([&](scf::ForOp forOp) -> void {
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
// condition 1: no other iter arguments depend on it

View File

@@ -32,7 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
PointerType ptrType = origType.getElementType().cast<PointerType>();
unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
auto pointeeType = ptrType.getPointeeType();
unsigned numBits = pointeeType.isa<triton::Float8Type>()
? 8
: pointeeType.getIntOrFloatBitWidth();
unsigned maxMultiple = info.getDivisibility(order[0]);
unsigned maxContig = info.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -11,6 +12,7 @@
//===----------------------------------------------------------------------===//
using namespace mlir;
namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
@@ -23,20 +25,25 @@ static Type getI1SameShape(Value v) {
tensorType.getEncoding());
}
#define int_attr(num) builder.getI64IntegerAttr(num)
namespace {
class LoopPipeliner {
/// cache forOp we are working on
/// Cache forOp we are working on
scf::ForOp forOp;
/// cache YieldOp for this forOp
/// Cache YieldOp for this forOp
scf::YieldOp yieldOp;
/// loads to be pipelined
/// Loads to be pipelined
SetVector<Value> loads;
/// the value that each load will be mapped to (after layout conversion)
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
DenseMap<Value, Value> loadsBuffer;
/// load => buffer type (with shared layout after swizzling)
DenseMap<Value, RankedTensorType> loadsBufferType;
/// load => buffer at stage N
DenseMap<Value, SmallVector<Value>> loadStageBuffer;
/// load => after extract
@@ -46,7 +53,7 @@ class LoopPipeliner {
///
Value loopIterIdx;
/// comments on numStages:
/// Comments on numStages:
/// [0, numStages-1) are in the prologue
/// numStages-1 is appended after the loop body
int numStages;
@@ -56,6 +63,7 @@ class LoopPipeliner {
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation *> depOps;
@@ -66,9 +74,8 @@ class LoopPipeliner {
Value lookupOrDefault(Value origin, int stage);
/// returns a empty buffer of size <numStages, ...>
triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
OpBuilder &builder);
/// Returns a empty buffer of size <numStages, ...>
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
@@ -80,7 +87,7 @@ public:
/// Collect loads to pipeline. Return success if we can pipeline this loop
LogicalResult initialize();
/// emit pipelined loads (before loop body)
/// Emit pipelined loads (before loop body)
void emitPrologue();
/// emit pipelined loads (after loop body)
@@ -106,7 +113,7 @@ Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
}
void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
// Loop-invarant value. skip
// Loop-invariant value, skip
if (v.getParentRegion() != &forOp.getLoopBody())
return;
@@ -116,33 +123,30 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
return;
if (auto arg = v.dyn_cast<BlockArgument>()) {
deps.insert(v);
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
if (arg.getArgNumber() > 0) {
// Skip the first arg (loop induction variable)
// Otherwise the op idx is arg.getArgNumber()-1
deps.insert(v);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
deps);
}
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depends on value in previous iterations
// This is because v might depend on value in previous iterations
deps.insert(v);
for (Value op : v.getDefiningOp()->getOperands())
collectDeps(op, stages, deps);
}
}
triton::gpu::AllocTensorOp
LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
// allocate a buffer for each pipelined tensor
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
OpBuilder &builder) {
// Allocate a buffer for each pipelined tensor
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
Value convertLayout = loadsMapping[op->getResult(0)];
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
SmallVector<int64_t> shape(tensorType.getShape().begin(),
tensorType.getShape().end());
shape.insert(shape.begin(), numStages);
Type elementType = tensorType.getElementType();
// The encoding of the buffer is similar to the original tensor
Attribute encoding = tensorType.getEncoding();
auto bufferType = RankedTensorType::get(shape, elementType, encoding);
return builder.create<triton::gpu::AllocTensorOp>(convertLayout.getLoc(),
bufferType);
return builder.create<ttg::AllocTensorOp>(
convertLayout.getLoc(), loadsBufferType[op->getResult(0)]);
}
llvm_unreachable("Async copy's return should be of RankedTensorType");
}
@@ -178,40 +182,49 @@ LogicalResult LoopPipeliner::initialize() {
// other load in the prologue, which is against the point of the pipeline
// pass)
for (triton::LoadOp loadOp : allLoads) {
bool isCandiate = true;
bool isCandidate = true;
for (triton::LoadOp other : allLoads) {
if (loadDeps[loadOp].contains(other)) {
isCandiate = false;
isCandidate = false;
break;
}
}
// For now, we only pipeline loads that have one covert_layout (to smem) use
// We only pipeline loads that have one covert_layout (to dot_op) use
// TODO: lift this constraint in the future
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
if (isCandidate && loadOp.getResult().hasOneUse()) {
isCandidate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout =
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
isCandiate = true;
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}
} else
isCandiate = false;
isCandidate = false;
if (isCandiate)
if (isCandidate)
loads.insert(loadOp);
}
// we have some loads to pipeline
// We have some loads to pipeline
if (!loads.empty()) {
// update depArgs & depOps
// Update depArgs & depOps
for (Value loadOp : loads) {
for (Value dep : loadDeps[loadOp]) {
// TODO: we should record the stage that the value is depended on
@@ -242,16 +255,16 @@ void LoopPipeliner::emitPrologue() {
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (int stage = 0; stage < numStages - 1; ++stage) {
// special handling for induction variable as the increment is implicit
// Special handling for induction variable as the increment is implicit
if (stage != 0)
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
setValueMapping(forOp.getInductionVar(), iv, stage);
// special handling for loop condition as there is no condition in ForOp
// Special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// rematerialize peeled values
// Rematerialize peeled values
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
@@ -305,11 +318,11 @@ void LoopPipeliner::emitPrologue() {
}
}
// update mapping of results
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
// TODO: load should no be used in the preheader?
// TODO: load should not be used in the preheader?
if (loads.contains(originalResult)) {
break;
// originalResult = loadsMapping[originalResult];
@@ -330,23 +343,25 @@ void LoopPipeliner::emitPrologue() {
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
// async.wait & extract_slice
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType =
RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
intAttr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
loadsExtract[loadOp] = extractSlice;
}
// bump up loopIterIdx, this is used for getting the correct slice for the
// Bump up loopIterIdx, this is used for getting the correct slice for the
// *next* iteration
loopIterIdx = builder.create<arith::AddIOp>(
loopIterIdx.getLoc(), loopIterIdx,
@@ -355,9 +370,6 @@ void LoopPipeliner::emitPrologue() {
void LoopPipeliner::emitEpilogue() {
// If there's any outstanding async copies, we need to wait for them.
// TODO(Keren): We may want to completely avoid the async copies in the last
// few iterations by setting is_masked attribute to true. We don't want to use
// the mask operand because it's a tensor but not a scalar.
OpBuilder builder(forOp);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointAfter(forOp);
@@ -367,12 +379,12 @@ void LoopPipeliner::emitEpilogue() {
scf::ForOp LoopPipeliner::createNewForOp() {
OpBuilder builder(forOp);
// order of new args:
// (original args),
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages-1)
// (iv at stage numStages-1)
// Order of new args:
// (original args)
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages - 1)
// (iv at stage numStages - 2)
// (pipeline iteration index)
// (loop iteration index)
SmallVector<Value> newLoopArgs;
@@ -413,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// 2.1 clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
@@ -454,15 +467,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
++argIdx;
}
// special handling for iv & loop condition
// Special handling for iv & loop condition
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
nextMapping.map(forOp.getInductionVar(), nextIV);
// slice index
// Slice index
SmallVector<Value> nextBuffers;
SmallVector<Value> extractSlices;
@@ -477,11 +491,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
extractSliceIndex = builder.create<arith::IndexCastOp>(
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask
// Update loading mask
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
@@ -491,7 +503,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
mask.getLoc(), mask.getType(), nextLoopCond);
newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// If mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
@@ -508,20 +520,24 @@ scf::ForOp LoopPipeliner::createNewForOp() {
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
nextBuffers.push_back(insertAsyncOp);
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType = RankedTensorType::get(sliceType.getShape(),
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<tensor::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
SmallVector<OpFoldResult>{intAttr(1),
intAttr(sliceType.getShape()[0]),
intAttr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
extractSlices.push_back(nextOp->getResult(0));
} else
nextOp = builder.clone(*op, nextMapping);
// update mapping of results
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
// if this is a loop-carried value, update the mapping for yield
// If this is a loop-carried value, update the mapping for yield
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpOperand &operand : originYield->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) {
@@ -534,15 +550,44 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
}
{
OpBuilder::InsertionGuard guard(builder);
for (Operation &op : *newForOp.getBody()) {
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
builder.setInsertionPoint(&op);
auto dotType = dotOp.getType().cast<RankedTensorType>();
Value a = dotOp.a();
Value b = dotOp.b();
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
auto newEncoding = ttg::DotOperandEncodingAttr::get(
tensorType.getContext(), opIdx, dotType.getEncoding());
auto newType =
RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), newEncoding);
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
newType, dotOperand);
}
return dotOperand;
};
a = layoutCast(a, 0);
b = layoutCast(b, 1);
dotOp->setOperand(0, a);
dotOp->setOperand(1, b);
}
}
}
// async.wait & extract_slice
Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
// move extract_slice after asyncWait
it->getDefiningOp()->moveAfter(asyncWait);
}
// bump iteration count
// Bump iteration count
pipelineIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
@@ -559,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (Value nextSlice : extractSlices)
yieldValues.push_back(nextSlice);
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
auto arg = newForOp.getRegionIterArgs()[i];
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
yieldValues.push_back(depArgsMapping[arg]);
}
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);

View File

@@ -0,0 +1,313 @@
//===----------------------------------------------------------------------===//
//
// This pass tries to prefetch operands (a and b) of tt.dot.
// Those ConvertLayoutOps will be lowered to shared memory loads.
//
// For example:
// %a: tensor<128x32xf16, #enc>
// scf.for %iv = ... iter_args(%a_arg = %a, ...) {
// %d = tt.dot %a_arg, %b, %c
// ...
// scf.yield %a_next, ...
// }
//
// will be translated to
//
// %a: tensor<128x32xf16, #enc>
// %a_tmp = tensor.extract_slice %a[0, 0] [128, 16]
// %a_prefetch = triton_gpu.convert_layout %a_tmp
// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch)
// {
// %x = tt.dot %a_arg, %b, %c
// %a_tmp_rem = tensor.extract_slice %a_buf[0, 16] [128, 16]
// %a_prefetch_next = triton_gpu.convert_layout %a_tmp_rem
// ...
// scf.yield %next_a, ..., %a_prefetch_next
// }
//===----------------------------------------------------------------------===//
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
class Prefetcher {
/// cache the ForOp we are working on
scf::ForOp forOp;
/// cache the YieldOp of this ForOp
scf::YieldOp yieldOp;
///
// TODO: add a hook to infer prefetchWidth
unsigned prefetchWidth = 16;
/// dots to be prefetched
SetVector<Value> dots;
/// dot => dot operand
DenseMap<Value, Value> dot2aLoopArg;
DenseMap<Value, Value> dot2aHeaderDef;
DenseMap<Value, Value> dot2bLoopArg;
DenseMap<Value, Value> dot2bHeaderDef;
DenseMap<Value, Value> dot2aYield;
DenseMap<Value, Value> dot2bYield;
/// operand => defining
DenseMap<Value, Value> operand2headPrefetch;
LogicalResult isForOpOperand(Value v);
Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK = llvm::None,
llvm::Optional<int64_t> shapeK = llvm::None);
public:
Prefetcher() = delete;
Prefetcher(scf::ForOp forOp) : forOp(forOp) {
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
LogicalResult initialize();
void emitPrologue();
scf::ForOp createNewForOp();
};
Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK,
llvm::Optional<int64_t> shapeK) {
// opIdx: 0 => a, 1 => b
auto type = v.getType().cast<RankedTensorType>();
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
SmallVector<int64_t> offset{0, 0};
Type elementType = type.getElementType();
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
// k => (prefetchWidth, k - prefetchWidth)
int64_t kIdx = opIdx == 0 ? 1 : 0;
offset[kIdx] = isPrologue ? 0 : prefetchWidth;
shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth);
if (shapeK)
shape[kIdx] = *shapeK;
if (offsetK)
offset[kIdx] = *offsetK;
Value newSmem = builder.create<tensor::ExtractSliceOp>(
v.getLoc(),
// TODO: encoding?
RankedTensorType::get(shape, elementType, type.getEncoding()), v,
SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
return prefetchSlice;
}
LogicalResult Prefetcher::initialize() {
Block *loop = forOp.getBody();
SmallVector<triton::DotOp> dotsInFor;
for (Operation &op : *loop)
if (auto dotOp = dyn_cast<triton::DotOp>(op))
dotsInFor.push_back(dotOp);
if (dotsInFor.empty())
return failure();
// TODO: segfault (original for still has uses)
// when used in flash attention that has 2 dots in the loop
if (dotsInFor.size() > 1)
return failure();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> Value {
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
if (isSharedEncoding(cvt.getOperand()))
return cvt.src();
return Value();
};
auto getIncomingOp = [this](Value v) -> Value {
if (auto arg = v.dyn_cast<BlockArgument>())
if (arg.getOwner()->getParentOp() == forOp.getOperation())
return forOp.getOpOperandForRegionIterArg(arg).get();
return Value();
};
auto getYieldOp = [this](Value v) -> Value {
auto arg = v.cast<BlockArgument>();
unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars();
return yieldOp.getOperand(yieldIdx);
};
for (triton::DotOp dot : dotsInFor) {
auto kSize = dot.a().getType().cast<RankedTensorType>().getShape()[1];
// Skip prefetching if kSize is less than prefetchWidth
if (kSize < prefetchWidth)
continue;
Value aSmem = getPrefetchSrc(dot.a());
Value bSmem = getPrefetchSrc(dot.b());
if (aSmem && bSmem) {
Value aHeaderDef = getIncomingOp(aSmem);
Value bHeaderDef = getIncomingOp(bSmem);
// Only prefetch loop arg
if (aHeaderDef && bHeaderDef) {
dots.insert(dot);
dot2aHeaderDef[dot] = aHeaderDef;
dot2bHeaderDef[dot] = bHeaderDef;
dot2aLoopArg[dot] = aSmem;
dot2bLoopArg[dot] = bSmem;
dot2aYield[dot] = getYieldOp(aSmem);
dot2bYield[dot] = getYieldOp(bSmem);
}
}
}
return success();
}
void Prefetcher::emitPrologue() {
OpBuilder builder(forOp);
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
Value aPrefetched =
generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()] = aPrefetched;
Value bPrefetched =
generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()] = bPrefetched;
}
}
scf::ForOp Prefetcher::createNewForOp() {
OpBuilder builder(forOp);
SmallVector<Value> loopArgs;
for (auto v : forOp.getIterOperands())
loopArgs.push_back(v);
for (Value dot : dots) {
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()]);
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()]);
}
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), loopArgs);
auto largestPow2 = [](int64_t n) -> int64_t {
while ((n & (n - 1)) != 0)
n = n & (n - 1);
return n;
};
builder.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
auto dot = dyn_cast<triton::DotOp>(&op);
if (dots.contains(dot)) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
// prefetched dot
Operation *firstDot = builder.clone(*dot, mapping);
if (Value a = operand2headPrefetch.lookup(dot.a()))
firstDot->setOperand(
0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
if (Value b = operand2headPrefetch.lookup(dot.b()))
firstDot->setOperand(
1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));
// remaining part
int64_t kOff = prefetchWidth;
int64_t kRem = dot.a().getType().cast<RankedTensorType>().getShape()[1] -
prefetchWidth;
Operation *prevDot = firstDot;
while (kRem != 0) {
int64_t kShape = largestPow2(kRem);
Value aRem =
generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
dotEncoding, builder, kOff, kShape);
Value bRem =
generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
dotEncoding, builder, kOff, kShape);
newOp = builder.clone(*dot, mapping);
newOp->setOperand(0, aRem);
newOp->setOperand(1, bRem);
newOp->setOperand(2, prevDot->getResult(0));
prevDot = newOp;
kOff += kShape;
kRem -= kShape;
}
}
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// prefetch next iteration
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v));
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
true, dotEncoding, builder));
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
true, dotEncoding, builder));
}
// Update ops of yield
builder.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
return newForOp;
}
struct PrefetchPass : public TritonGPUPrefetchBase<PrefetchPass> {
void runOnOperation() override {
getOperation()->walk([&](scf::ForOp forOp) {
Prefetcher prefetcher(forOp);
if (prefetcher.initialize().failed())
return;
prefetcher.emitPrologue();
scf::ForOp newForOp = prefetcher.createNewForOp();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUPrefetchPass() {
return std::make_unique<PrefetchPass>();
}

View File

@@ -1,102 +0,0 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
SwizzlePass() = default;
struct SwizzleInfo {
int vec;
int perPhase;
int maxPhase;
};
SwizzleInfo getSwizzleMMA(int opIdx, triton::gpu::MmaEncodingAttr retEncoding,
RankedTensorType ty) {
SwizzleInfo noSwizzling = {1, 1, 1};
int version = retEncoding.getVersion();
auto tyEncoding = ty.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto order = tyEncoding.getOrder();
// number of rows per phase
int perPhase = 128 / (ty.getShape()[order[0]] *
(ty.getElementType().getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
size_t inner = (opIdx == 0) ? 0 : 1;
if (version == 1) {
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
// TODO: handle rep (see
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
int vec = 1;
return SwizzleInfo{vec, perPhase, maxPhase};
} else if (version == 2) {
auto eltTy = ty.getElementType();
std::vector<size_t> mat_shape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
bool is_int8_mma = ty.getElementType().isInteger(8);
if (is_int8_mma && order[0] == inner)
return noSwizzling;
// compute swizzling for A operand
if (opIdx == 0) {
int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
int maxPhase = mmaStride / perPhase;
return SwizzleInfo{vec, perPhase, maxPhase};
}
// compute swizzling for B operand
else if (opIdx == 1) {
int vec = order[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
int mmaStride = order[0] == 1 ? mat_shape[2] : mat_shape[1];
int maxPhase = mmaStride / perPhase;
return SwizzleInfo{vec, perPhase, maxPhase};
} else {
llvm_unreachable("invalid operand index");
}
} else
llvm_unreachable("unsupported swizzling for provided MMA version");
}
void runOnOperation() override {
Operation *op = getOperation();
op->walk([&](triton::DotOp dotOp) -> void {
OpBuilder builder(dotOp);
auto _retEncoding =
dotOp.getResult().getType().cast<RankedTensorType>().getEncoding();
auto retEncoding = _retEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
if (!retEncoding)
return;
for (int opIdx : {0, 1}) {
Value op = dotOp.getOperand(opIdx);
auto ty = op.getType().template cast<RankedTensorType>();
// compute new swizzled encoding
SwizzleInfo swizzle = getSwizzleMMA(opIdx, retEncoding, ty);
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
ty.getEncoding()
.cast<triton::gpu::SharedEncodingAttr>()
.getOrder());
// create conversion
auto newType = RankedTensorType::get(ty.getShape(), ty.getElementType(),
newEncoding);
Operation *newOp = builder.create<triton::gpu::ConvertLayoutOp>(
op.getLoc(), newType, op);
// bind new op to dot operand
dotOp->replaceUsesOfWith(op, newOp->getResult(0));
}
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUSwizzlePass() {
return std::make_unique<SwizzlePass>();
}

View File

@@ -35,7 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
});
//
// materailizations
// Materializations
//
// This will be called when (newArgType != origArgType)
// This will create newArg, and map(origArg, newArg)
@@ -95,8 +95,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::SharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::SharedEncodingAttr>())
if (aEncoding && aEncoding.isa<triton::gpu::DotOperandEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return true;
return false;
});

View File

@@ -1,126 +0,0 @@
#include "triton/Target/AMDGCN/AMDGCNTranslation.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
namespace {
void init_llvm() {
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
LLVMInitializeAMDGPUAsmParser();
}
}
namespace triton {
std::tuple<std::string, std::string>
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc) {
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string proc = cc;
std::string layout = "";
std::string features = "+sramecc,-xnack";
std::string kernel_name =
std::string(std::tmpnam(nullptr)) + "_" + module.getModuleIdentifier();
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(module);
// create machine
module.setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module.setDataLayout(machine->createDataLayout());
else
module.setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module.functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(module);
std::string amdgcn(buffer.begin(), buffer.end());
if (tools::getBoolEnv("AMDGCN_ENABLE_DUMP")) {
llvm::outs() << amdgcn << "\n";
}
// create dump files
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = kernel_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec) {
llvm::errs() << "Error: '" << isabin_path << "' was not created. \n";
llvm::errs() << "\tError reported: " << ec.message() << "\n";
return std::make_tuple("", "");
}
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
llvm::CGFT_ObjectFile);
pass.run(module);
// generate HASCO file
std::string hsaco_path = kernel_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu",
"-shared", "-o", hsaco_path, isabin_path},
llvm::None, {}, 0, 0, &error_message);
if (lld_result) {
llvm::errs() << "Error: ld.lld execute fail. Error code: " << lld_result
<< "\n";
llvm::errs() << "\tError reported: " << error_message << "\n";
return std::make_tuple("", "");
}
const std::string hsaco_dump_path = tools::getenv("AMDGCN_HSACO_DUMP_PATH");
if (!hsaco_dump_path.empty()) {
if (std::error_code copy_result =
llvm::sys::fs::copy_file(hsaco_path, hsaco_dump_path)) {
llvm::errs() << "Error: cannot copy to hsaco dump file from '"
<< hsaco_path << "' to '" << hsaco_dump_path << "'\n";
llvm::errs() << "\tError reported: " << copy_result.message() << "\n";
return std::make_tuple("", "");
}
}
return std::make_tuple(amdgcn, hsaco_path);
}
} // namespace triton

View File

@@ -1,12 +0,0 @@
add_mlir_translation_library(TritonAMDGCN
AMDGCNTranslation.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
)

View File

@@ -1,3 +1,3 @@
add_subdirectory(AMDGCN)
add_subdirectory(LLVMIR)
add_subdirectory(PTX)
add_subdirectory(HSACO)

View File

@@ -0,0 +1,9 @@
add_mlir_translation_library(TritonHSACO
HSACOTranslation.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
TritonLLVMIR
)

View File

@@ -0,0 +1,170 @@
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <regex>
#include <cstdio>
#include <iostream>
#include <memory>
namespace {
void init_llvm() {
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmParser();
LLVMInitializeAMDGPUAsmPrinter();
}
std::unique_ptr<llvm::TargetMachine> initialize_module(llvm::Module* module,
const std::string& triple,
const std::string& proc,
const std::string& features) {
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
module->setDataLayout(machine->createDataLayout());
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
return std::unique_ptr<llvm::TargetMachine>(machine);
}
std::string generate_amdgcn_assembly(llvm::Module* module,
const std::string& triple,
const std::string& proc,
const std::string& features) {
auto machine = initialize_module(module, triple, proc, features);
llvm::SmallVector<char, 0> buffer;
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
std::string amdgcn(buffer.begin(), buffer.end());
if (::triton::tools::getBoolEnv("AMDGCN_ENABLE_DUMP")) {
std::cout << "// -----// AMDGCN Dump //----- //\n" << amdgcn << std::endl;
}
return amdgcn;
}
std::string generate_hsaco(llvm::Module* module,
const std::string& triple,
const std::string& proc,
const std::string& features) {
auto machine = initialize_module(module, triple, proc, features);
std::string kernel_name = std::string(std::tmpnam(nullptr)) + "_" + module->getModuleIdentifier();
// create dump files
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = kernel_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec)
{
std::cout << isabin_path << " was not created. error code: " << ec << std::endl;
}
// emit
llvm::legacy::PassManager pass;
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
pass.run(*module);
// generate HASCO file
std::string hsaco_path = kernel_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
llvm::None, {}, 0, 0, &error_message);
if (lld_result)
{
std::cout << "ld.lld execute fail: " << std::endl;
std::cout << error_message << std::endl;
std::cout << lld_result << std::endl;
}
return hsaco_path;
}
std::tuple<std::string, std::string> llir_to_amdgcn_and_hsaco(llvm::Module* module,
std::string cc) {
// create
std::string triple = "amdgcn-amd-amdhsa";
std::string proc = cc;
std::string features = "+sramecc,-xnack";
init_llvm();
// verify and store llvm
auto module_obj = llvm::CloneModule(*module);
auto amdgcn = generate_amdgcn_assembly(module, triple, proc, features);
auto hsaco_path = generate_hsaco(module_obj.get(), triple, proc, features);
return std::make_tuple(amdgcn, hsaco_path);
}
}
namespace triton {
std::tuple<std::string, std::string> translateLLVMIRToHSACO(llvm::Module& module,
std::string cc) {
auto hsacoCode = llir_to_amdgcn_and_hsaco(&module, cc);
return hsacoCode;
}
} // namespace triton

View File

@@ -1,27 +1,28 @@
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/tools/sys/getenv.hpp"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include <iostream>
namespace mlir {
namespace triton {
@@ -38,7 +39,6 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
auto *module = func->getParent();
auto &ctx = func->getContext();
#ifndef USE_ROCM
if (metadata.maxntidx > 0) {
auto i32_ty = llvm::IntegerType::get(ctx, 32);
auto warps =
@@ -51,7 +51,6 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, md_args));
}
#endif
if (metadata.is_kernel) {
#ifndef USE_ROCM
@@ -76,14 +75,14 @@ void extractNVVMMetadata(mlir::ModuleOp module,
bool hasMetadata{};
// maxntid
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
if (op->hasAttr("nvvm.maxntid")) {
auto attr = op->getAttr("nvvm.maxntid");
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
hasMetadata = true;
}
// kernel
if (op->hasAttr(NVVMMetadataField::Kernel)) {
if (op->hasAttr("nvvm.kernel")) {
meta.is_kernel = true;
hasMetadata = true;
}
@@ -98,13 +97,8 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
mlir::registerLLVMDialectTranslation(registry);
#ifdef USE_ROCM
mlir::registerROCDLDialectTranslation(registry);
#else
mlir::registerNVVMDialectTranslation(registry);
#endif
context->appendDialectRegistry(registry);
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
@@ -116,9 +110,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
return nullptr;
}
// Initialize LLVM targets.
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
@@ -139,7 +130,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module) {
mlir::ModuleOp module, int computeCapability) {
mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);
auto printingFlags = mlir::OpPrintingFlags();
@@ -154,8 +145,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(createConvertTritonGPUToLLVMPass());
// Conanicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
pm.addPass(mlir::createSymbolDCEPass());
@@ -208,6 +199,14 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
if (::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) {
std::string mod_string;
std::unique_ptr<llvm::raw_string_ostream> ir_ss(
new llvm::raw_string_ostream(mod_string));
llvmir->print(*ir_ss, nullptr);
std::cout << "// -----// LLVM IR Dump //----- //\n" << mod_string << std::endl;
}
return llvmir;
}
@@ -228,7 +227,6 @@ void addExternalLibs(mlir::ModuleOp &module,
DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs);
module.getOperation()->setAttr("triton_gpu.externs", dict);
return;
}
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {

View File

@@ -1,69 +1,39 @@
#include "triton/Target/PTX/PTXTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <filesystem>
#include <regex>
namespace triton {
extern "C" {
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
static void init_llvm() {
static void initLLVM() {
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
}
static bool find_and_replace(std::string &str, const std::string &begin,
const std::string &end,
const std::string &target) {
size_t start_replace = str.find(begin);
if (start_replace == std::string::npos)
static bool findAndReplace(std::string &str, const std::string &begin,
const std::string &end, const std::string &target) {
size_t startReplace = str.find(begin);
if (startReplace == std::string::npos)
return false;
size_t end_replace = str.find(end, start_replace);
if (end_replace == std::string::npos)
size_t endReplace = str.find(end, startReplace);
if (endReplace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
str.replace(startReplace, endReplace + 1 - startReplace, target);
return true;
}
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
static void linkExternal(llvm::Module &module) {
bool hasExternal = false;
for (auto &func : *module) {
for (auto &func : module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
@@ -80,16 +50,14 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(*module, libdevice.string()))
if (mlir::triton::linkExternLib(module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
}
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
{
auto &ctx = module->getContext();
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
@@ -97,81 +65,80 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module->addModuleFlag(reflect);
module.addModuleFlag(reflect);
}
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
linkExternal(module);
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
// int max_nvvm_ptx = 74;
int maxNNVMCC = 75;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *short_ptr =
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
assert(shortPtr);
shortPtr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(capability);
std::string sm = "sm_" + std::to_string(cc);
// max PTX version
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
int ptxMajor = version / 10;
int ptxMinor = version % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(capability, max_nvvm_cc));
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
init_llvm();
initLLVM();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
pm.run(module);
// module->print(llvm::outs(), nullptr);
// create machine
module->setTargetTriple(triple);
module.setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
module.setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
module.setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
for (llvm::Function &f : module.functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
pass.run(module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n",
".version " + std::to_string(ptx_major) + "." +
std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
findAndReplace(result, ".version", "\n",
".version " + std::to_string(ptxMajor) + "." +
std::to_string(ptxMinor) + "\n");
findAndReplace(result, ".target", "\n", ".target " + sm + "\n");
while (findAndReplace(result, "\t// begin inline asm", "\n", ""))
;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
while (findAndReplace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto ptxCode = llir_to_ptx(&module, cc, version);
return ptxCode;
}
} // namespace triton

View File

@@ -6,7 +6,6 @@ import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from typing import NamedTuple
@@ -25,8 +24,11 @@ def get_build_type():
return "Debug"
elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo"
elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"):
return "TritonRelBuildWithAsserts"
else:
return "Release"
# TODO: change to release when stable enough
return "TritonRelBuildWithAsserts"
# --- third party packages -----
@@ -124,19 +126,14 @@ class CMakeBuild(build_ext):
self.build_extension(ext)
def build_extension(self, ext):
self.debug = True
lit_dir = shutil.which('lit')
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
# lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
build_suffix = 'debug' if self.debug else 'release'
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
if not os.path.exists(llvm_build_dir):
os.makedirs(llvm_build_dir)
# python directories
python_include_dir = distutils.sysconfig.get_python_inc()
cmake_args = [
@@ -145,13 +142,13 @@ class CMakeBuild(build_ext):
"-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON",
# '-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DLLVM_EXTERNAL_LIT=" + lit_dir
] + thirdparty_cmake_args
# configuration
cfg = "Debug" if self.debug else "Release"
cfg = get_build_type()
build_args = ["--config", cfg]
if platform.system() == "Windows":
@@ -183,7 +180,11 @@ setup(
"torch",
"lit",
],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
package_data={
"triton/ops": ["*.c"],
"triton/ops/blocksparse": ["*.c"],
"triton/language": ["*.bc"]
},
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild},

View File

@@ -11,9 +11,10 @@
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
@@ -21,7 +22,8 @@
#include "triton/Target/AMDGCN/AMDGCNTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/tools/sys/getenv.hpp"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
@@ -106,7 +108,7 @@ void init_triton_ir(py::module &&m) {
.value("AND", mlir::triton::RMWOp::AND)
.value("OR", mlir::triton::RMWOp::OR)
.value("XOR", mlir::triton::RMWOp::XOR)
// .value("XCHG", mlir::triton::RMWOp::Xchg)
.value("XCHG", mlir::triton::RMWOp::XCHG)
.value("MAX", mlir::triton::RMWOp::MAX)
.value("MIN", mlir::triton::RMWOp::MIN)
.value("UMIN", mlir::triton::RMWOp::UMIN)
@@ -116,6 +118,11 @@ void init_triton_ir(py::module &&m) {
.def(py::init<>())
.def("load_triton", [](mlir::MLIRContext &self) {
self.getOrLoadDialect<mlir::triton::TritonDialect>();
// we load LLVM because the frontend uses LLVM.undef for
// some placeholders
self.getOrLoadDialect<mlir::triton::TritonDialect>();
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
self.getOrLoadDialect<mlir::gpu::GPUDialect>();
});
// .def(py::init([](){
// mlir::MLIRContext context;
@@ -164,7 +171,19 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::Type>(m, "type")
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16);
.def("is_fp16", &mlir::Type::isF16)
.def("__str__", [](mlir::Type &self) {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return os.str();
});
py::class_<mlir::FunctionType>(m, "function_type")
.def("param_types", [](mlir::FunctionType &self) {
return std::vector<mlir::Type>(self.getInputs().begin(),
self.getInputs().end());
});
py::class_<mlir::Value>(m, "value")
.def("set_attr",
@@ -173,15 +192,16 @@ void init_triton_ir(py::module &&m) {
if (mlir::Operation *definingOp = self.getDefiningOp())
definingOp->setAttr(name, attr);
else {
/* issue an warning */
/* issue a warning */
}
})
.def("get_context", &mlir::Value::getContext)
.def("replace_all_uses_with",
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
});
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument");
py::class_<mlir::Region>(m, "region")
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
@@ -289,7 +309,7 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
.def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp");
// dynamic_attr is used to transfer ownership of the MLIR context to the
// module
@@ -315,12 +335,30 @@ void init_triton_ir(py::module &&m) {
.def("get_function",
[](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
});
})
.def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
llvm::SmallVector<mlir::FuncOp> funcs;
self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
if (funcs.size() != 1)
throw std::runtime_error("Expected a single function");
return funcs[0];
});
m.def("make_attr",
[](const std::vector<int> &values, mlir::MLIRContext &context) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(
{static_cast<int64_t>(values.size())},
mlir::IntegerType::get(&context, 32)),
values)
.cast<mlir::Attribute>();
});
m.def(
"parse_mlir_module",
[](const std::string &inputFilename, mlir::MLIRContext &context) {
// initialize registry
// note: we initialize llvm for undef
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
@@ -364,6 +402,7 @@ void init_triton_ir(py::module &&m) {
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
.def_property_readonly("type", &mlir::FuncOp::getType)
.def("reset_type", &mlir::FuncOp::setType);
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -424,7 +463,7 @@ void init_triton_ir(py::module &&m) {
.def("get_bool_attr", &mlir::OpBuilder::getBoolAttr)
.def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
// Use arith.ConstantOp to create constants
// // Constants
// Constants
.def("get_int1",
[](mlir::OpBuilder &self, bool v) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -468,6 +507,15 @@ void init_triton_ir(py::module &&m) {
else
throw std::runtime_error("Not implemented");
})
.def("get_one_value",
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
auto loc = self.getUnknownLoc();
uint64_t val = 0x1;
if (auto intTy = type.dyn_cast<mlir::IntegerType>())
return self.create<mlir::arith::ConstantIntOp>(loc, val, intTy);
else
throw std::runtime_error("Not implemented");
})
// Types
.def("get_void_ty",
@@ -494,10 +542,6 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::Float8Type>();
})
.def("get_bf8_ty",
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::BFloat8Type>();
})
.def(
"get_half_ty",
[](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); })
@@ -536,7 +580,7 @@ void init_triton_ir(py::module &&m) {
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
mlir::ArrayRef<mlir::NamedAttribute> attrs = {
llvm::SmallVector<mlir::NamedAttribute> attrs = {
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
self.getStringAttr(visibility))};
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
@@ -593,14 +637,14 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
})
.def("create_condtion_op",
.def("create_condition_op",
[](mlir::OpBuilder &self, mlir::Value &cond,
std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ConditionOp>(loc, cond, args);
})
// miscellious
// miscellaneous
.def("create_make_range",
[](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -617,14 +661,20 @@ void init_triton_ir(py::module &&m) {
})
// Cast instructions
// Conversions for custom FP types (FP8)
.def("create_fp_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::FpToFpOp>(loc, dstType, src);
})
// Conversions for standard LLVM builtin types
.def("create_bitcast",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
@@ -698,7 +748,6 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getI32Type());
})
.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
@@ -1042,15 +1091,15 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
})
// // Input/Output
// Input/Output
.def("create_load",
[](mlir::OpBuilder &self, mlir::Value &ptrs,
mlir::triton::CacheModifier cacheModifer,
mlir::triton::CacheModifier cacheModifier,
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::LoadOp>(
loc, ptrs, cacheModifer, evictionPolicy, isVolatile);
loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
})
.def("create_store",
[](mlir::OpBuilder &self, mlir::Value &ptrs,
@@ -1114,6 +1163,16 @@ void init_triton_ir(py::module &&m) {
mlir::RankedTensorType::get(shape, lhsType.getElementType()),
lhs, rhs);
})
.def("create_trans",
[](mlir::OpBuilder &self, mlir::Value &arg) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
auto argEltType = argType.getElementType();
std::vector<int64_t> retShape = argType.getShape();
std::reverse(retShape.begin(), retShape.end());
return self.create<mlir::triton::TransOp>(
loc, mlir::RankedTensorType::get(retShape, argEltType), arg);
})
.def("create_broadcast",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
@@ -1141,9 +1200,19 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
mlir::Type dstType = ptrType.getPointeeType();
mlir::Type dstType;
if (auto srcTensorType =
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
mlir::Type dstElemType = srcTensorType.getElementType()
.cast<mlir::triton::PointerType>()
.getPointeeType();
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
dstElemType);
} else {
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
cmp, val);
})
@@ -1152,9 +1221,19 @@ void init_triton_ir(py::module &&m) {
mlir::Value &ptr, mlir::Value &val,
mlir::Value &mask) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
mlir::Type dstType = ptrType.getPointeeType();
mlir::Type dstType;
if (auto srcTensorType =
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
mlir::Type dstElemType = srcTensorType.getElementType()
.cast<mlir::triton::PointerType>()
.getPointeeType();
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
dstElemType);
} else {
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
ptr, val, mask);
})
@@ -1183,11 +1262,10 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, bool allowTF32, bool transA,
bool transB) -> mlir::Value {
mlir::Value &c, bool allowTF32) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
allowTF32, transA, transB);
allowTF32);
})
.def("create_exp",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
@@ -1222,10 +1300,11 @@ void init_triton_ir(py::module &&m) {
operand.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = inputTensorType.getShape();
shape.erase(shape.begin() + axis);
mlir::Type resType = inputTensorType.getElementType();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
mlir::Type resType = withIndex ? self.getI32Type()
: inputTensorType.getElementType();
if (!shape.empty()) {
resType = mlir::RankedTensorType::get(
shape, inputTensorType.getElementType());
resType = mlir::RankedTensorType::get(shape, resType);
}
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
operand, axis);
@@ -1258,7 +1337,18 @@ void init_triton_ir(py::module &&m) {
mlir::StringAttr::get(self.getContext(),
llvm::StringRef(prefix)),
values);
});
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<::mlir::LLVM::UndefOp>(loc, type);
})
// Force GPU barrier
.def("create_barrier", [](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::gpu::BarrierOp>(loc);
});
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
@@ -1322,13 +1412,14 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self, int numStages) {
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
})
.def("add_triton_gpu_combine_pass",
.def("add_tritongpu_prefetch_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCombineOpsPass());
self.addPass(mlir::createTritonGPUPrefetchPass());
})
.def("add_triton_gpu_swizzle_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUSwizzlePass());
.def("add_triton_gpu_combine_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(
mlir::createTritonGPUCombineOpsPass(computeCapability));
})
.def("add_triton_gpu_to_llvm",
[](mlir::PassManager &self) {
@@ -1342,17 +1433,17 @@ void init_triton_ir(py::module &&m) {
void init_triton_translation(py::module &m) {
using ret = py::return_value_policy;
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
m.def("get_shared_memory_size", [](mlir::ModuleOp mod) {
auto shared = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
return shared.getInt();
});
m.def(
"translate_triton_gpu_to_llvmir",
[](mlir::ModuleOp op) {
[](mlir::ModuleOp op, int computeCapability) {
llvm::LLVMContext llvmContext;
auto llvmModule =
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
&llvmContext, op, computeCapability);
if (!llvmModule)
llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");
@@ -1374,6 +1465,12 @@ void init_triton_translation(py::module &m) {
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module) {
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
// translate module to PTX
auto ptxCode =
triton::translateLLVMIRToPTX(*module, capability, version);
@@ -1426,7 +1523,7 @@ void init_triton_translation(py::module &m) {
});
m.def(
"translate_llvmir_to_amdgcn",
"translate_llvmir_to_hsaco",
[](const std::string llvmIR, std::string cc) -> std::tuple<std::string, std::string> {
// create LLVM module from C++
llvm::LLVMContext context;
@@ -1437,7 +1534,7 @@ void init_triton_translation(py::module &m) {
llvm::parseIR(buffer->getMemBufferRef(), error, context);
// translate module to HSACO
auto hsacoCode =
triton::translateLLVMIRToAMDGCN(*module, cc);
triton::translateLLVMIRToHSACO(*module, cc);
return hsacoCode;
},
ret::take_ownership);

View File

@@ -0,0 +1,164 @@
import subprocess
import sys
import pytest
import torch
import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = 'v100'
#######################
# Utilities
#######################
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(',')
ret = [int(x) for x in ret]
return ret
#######################
# Matrix Multiplication
#######################
sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
'v100': {
# square
(256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.695},
(4096, 4096, 4096): {'float16': 0.831},
(8192, 8192, 8192): {'float16': 0.849},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0128},
(16, 4096, 4096): {'float16': 0.0883},
(16, 8192, 8192): {'float16': 0.101},
(64, 1024, 1024): {'float16': 0.073},
(64, 4096, 4096): {'float16': 0.270},
(64, 8192, 8192): {'float16': 0.459},
(1024, 64, 1024): {'float16': 0.0692},
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
},
'a100': {
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
}
# # deep reductions
# (64 , 64 , 16384) : {'a100': 0.},
# (64 , 64 , 65536) : {'a100': 0.},
# (256 , 256 , 8192 ) : {'a100': 0.},
# (256 , 256 , 32768) : {'a100': 0.},
}
@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = sm_clocks[DEVICE_NAME]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
if dtype == torch.int8:
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
b = b.t() # only test row-col layout
else:
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
#######################
# Element-Wise
#######################
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
elementwise_data = {
'v100': {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.530,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,
},
'a100': {
1024 * 16: 0.008,
1024 * 64: 0.034,
1024 * 256: 0.114,
1024 * 1024: 0.315,
1024 * 4096: 0.580,
1024 * 16384: 0.782,
1024 * 65536: 0.850,
}
}
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = get_dram_gbps()
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

View File

@@ -14,13 +14,13 @@ def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
def test_empty_kernel_cubin_compile():
device = torch.cuda.current_device()
kernel = triton.compile(empty_kernel,
"*fp32,i32,i32",
device=device,
signature="*fp32,i32,i32",
constants={"BLOCK": 256})
assert len(kernel.asm["cubin"]) > 0
if torch.version.hip is not None:
assert len(kernel.asm["hsaco_path"]) > 0
else:
assert len(kernel.asm["cubin"]) > 0
def test_empty_kernel_launch():

File diff suppressed because it is too large Load Diff

View File

@@ -32,8 +32,10 @@ torch_ops = {
"where": "where",
}
libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'
if torch.version.hip is not None:
e_libs = None
else:
e_libs = {"libdevice": '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}
def get_tensor(shape, data_type, b_positive=False):
x = None
@@ -89,7 +91,7 @@ def kernel(X, Y, BLOCK: tl.constexpr):
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs=e_libs)
# reference result
y_ref = getattr(torch, torch_ops[expr])(x)
# compare
@@ -133,7 +135,7 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr):
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs=e_libs)
# reference result
if expr == "cdiv":
@@ -181,7 +183,7 @@ def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs=e_libs)
# reference result
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)

View File

@@ -6,7 +6,6 @@ from torch.testing import assert_close
import triton
import triton.language as tl
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
@@ -173,6 +172,9 @@ def kernel(X, Y, BLOCK: tl.constexpr):
# triton result
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path})
if torch.version.hip is not None:
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs=None)
else:
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path})
# compare
assert_close(y, y_ref)

View File

@@ -1,12 +1,13 @@
import os
import subprocess
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False)
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
(outs, err) = proc.communicate()
outs = outs.split()
new_lines = set()

View File

@@ -0,0 +1,198 @@
import numpy as np
import pytest
import scipy.stats
import torch
import triton
import triton.language as tl
#####################################
# Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
self.DTYPE = DTYPE
# This is better for GPU
PHILOX_32 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B9,
PHILOX_KEY_B=0xBB67AE85,
PHILOX_ROUND_A=0xD2511F53,
PHILOX_ROUND_B=0xCD9E8D57,
DTYPE=np.uint32,
)
# This is what numpy implements
PHILOX_64 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B97F4A7C15,
PHILOX_KEY_B=0xBB67AE8584CAA73B,
PHILOX_ROUND_A=0xD2E7470EE14C6C93,
PHILOX_ROUND_B=0xCA5A826395121157,
DTYPE=np.uint64,
)
class CustomPhilox4x:
def __init__(self, seed, config):
self._config = config
seed = self._into_pieces(seed)
self._key = np.array(seed[:2], dtype=self._dtype)
self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
@property
def _dtype(self):
return self._config.DTYPE
def _into_pieces(self, n, pad=4):
res = []
while len(res) < pad:
res.append(np.array(n, dtype=self._dtype))
n >>= (np.dtype(self._dtype).itemsize * 8)
assert n == 0
return tuple(res)
def _multiply_low_high(self, a, b):
low = a * b
high = int(a) * int(b)
high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
return low, high
def _single_round(self, counter, key):
lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
ret0 = hi1 ^ counter[1] ^ key[0]
ret1 = lo1
ret2 = hi0 ^ counter[3] ^ key[1]
ret3 = lo0
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
def _raise_key(self, key):
pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
return key + np.array(pk, dtype=self._dtype)
def random_raw(self):
counter = self._counter
key = self._key
for _ in range(10):
counter = self._single_round(counter, key)
key = self._raise_key(key)
self.advance(1)
return counter
def advance(self, n_steps):
self._counter[0] += n_steps
assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
class CustomPhilox(CustomPhilox4x):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.buffer = []
def random_raw(self):
if len(self.buffer) == 0:
self.buffer = list(super().random_raw())[::-1]
return int(self.buffer.pop())
#####################################
# Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randint(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.int32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
# reference result
gen = CustomPhilox4x(seed, config=PHILOX_32)
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.rand(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randn(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert abs(x.mean()) < 1e-2
assert abs(x.std() - 1) < 1e-2
# tl.rand() should never produce >=1.0
def test_rand_limits():
@triton.jit
def kernel(input, output, n: tl.constexpr):
idx = tl.arange(0, n)
x = tl.load(input + idx)
y = tl.random.uint32_to_uniform_float(x)
tl.store(output + idx, y)
min_max_int32 = torch.tensor([
torch.iinfo(torch.int32).min,
torch.iinfo(torch.int32).max,
], dtype=torch.int32, device='cuda')
output = torch.empty(2, dtype=torch.float32, device='cuda')
kernel[(1,)](min_max_int32, output, 2)
assert output[0] == output[1]
assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0

View File

@@ -0,0 +1,192 @@
import pytest
import torch
import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
# TODO: float32 fails
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
seed = 0
torch.manual_seed(seed)
is_sdd = MODE == "sdd"
is_dsd = MODE == "dsd"
is_dds = MODE == "dds"
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
# create inputs
# create op
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
c_shape = (Z, H, M, N)
shape = {
"sdd": (M, N),
"dsd": (a_shape[2], a_shape[3]),
"dds": (b_shape[2], b_shape[3]),
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
b_ref = do_mask(b_ref) if is_dds else b_ref
a_ref.retain_grad()
b_ref.retain_grad()
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref.backward(dc_ref)
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
# triton result
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.retain_grad()
b_tri.retain_grad()
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
da_tri = a_tri.grad
db_tri = b_tri.grad
# compare
triton.testing.assert_almost_equal(c_ref, c_tri)
triton.testing.assert_almost_equal(da_ref, da_tri)
triton.testing.assert_almost_equal(db_ref, db_tri)
configs = [
(16, 256),
(32, 576),
(64, 1871),
(128, 2511),
]
@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 3, WIDTH, WIDTH
# initialize layout
# make sure each row has at least one non-zero element
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
if is_dense:
layout[:] = 1
else:
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# initialize data
a_shape = (Z, H, M, N)
a_ref, a_tri = triton.testing.make_pair(a_shape)
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
# compute [torch]
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
a_ref[M == 0] = float("-inf")
out_ref = torch.softmax(a_ref * scale, -1)
out_ref.backward(dout_ref)
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
# compute [triton]
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
# compare
triton.testing.assert_almost_equal(out_tri, out_ref)
triton.testing.assert_almost_equal(da_tri, da_ref)
@pytest.mark.parametrize("block", [16, 32, 64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_attention_fwd_bwd(
block,
dtype,
input_scale=1.0,
scale=1 / 8.0,
n_ctx=256,
batch_size=2,
n_heads=2,
):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
# Triton:
n_blocks = n_ctx // block
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
query, key, value = [x.clone() for x in qkvs]
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
# Torch version:
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
attn_mask = torch.tril(attn_mask, diagonal=0)
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
torch_q.retain_grad()
torch_k.retain_grad()
torch_v.retain_grad()
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
scores = scores + attn_mask
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out ** 2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
triton.testing.assert_almost_equal(loss, torch_loss)
for g1, g2 in zip(grads, torch_grads):
triton.testing.assert_almost_equal(g1, g2)
@pytest.mark.parametrize("block", [16, 32, 64])
def triton_attention(
layout,
block: int,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
w = sparse_dot_sdd_nt(query, key)
w = sparse_softmax(w, scale=scale, is_causal=True)
a = sparse_dot_dsd_nn(w, value)
return a

Some files were not shown because too many files have changed in this diff Show More