Initial code merge of Hopper support (#2036)

This is the initial code merge of NVIDIA Hopper feature support. Please be
aware that the merge is not yet complete and troubleshooting is still
ongoing. The new hardware features (GMMA, TMA, STMATRIX, etc.) and automatic
warp specialization are experimental for now and turned off by default; the
CI workflow changes below exercise them through the ENABLE_TMA and
ENABLE_MMA_V3 environment variables. They are recommended for trial once
version 3.0 is released.

The work is contributed by ben-zhang-609, bealwang, donproc, qliu93, jsh20,
allatit23, LyricZhao, ivanyinwz, goostavz & yangjunpro from NVIDIA, in
cooperation with ptillet, Jokeren, ThomasRaoux & zahimoud from OpenAI.

Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
Author: goostavz
Date: 2023-08-07 09:53:04 +08:00
Committed by: GitHub
Parent: 5df904233c
Commit: f1512bded1
220 changed files with 28448 additions and 2295 deletions

View File

@@ -50,6 +50,8 @@ jobs:
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
echo "BACKEND=CUDA" >> "${GITHUB_ENV}"
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}"
- name: Clear cache
run: |
@@ -79,8 +81,32 @@ jobs:
fi
lit -v "${LIT_TEST_DIR}"
- name: Run python tests on CUDA
if: ${{ env.BACKEND == 'CUDA'}}
- name: Enable MMAV3 and TMA
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}}
run: |
echo "ENABLE_TMA=1" >> "${GITHUB_ENV}"
echo "ENABLE_MMA_V3=1" >> "${GITHUB_ENV}"
- name: Run python tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime
# run runtime tests serially to avoid race condition with cache handling.
python3 -m pytest runtime/
- name: Disable MMAV3 and TMA
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}}
run: |
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}"
- name: Clear cache
run: |
rm -rf ~/.triton
- name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime

View File

@@ -209,6 +209,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO

View File

@@ -9,6 +9,7 @@ target_link_libraries(triton-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
@@ -29,6 +30,7 @@ target_link_libraries(triton-reduce PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
@@ -48,6 +50,7 @@ llvm_update_compile_flags(triton-translate)
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO

View File

@@ -1,9 +1,11 @@
#pragma once
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
@@ -23,6 +25,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::registerTritonNvidiaGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
@@ -32,6 +35,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect>();

View File

@@ -14,6 +14,7 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
@@ -38,6 +39,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
mlir::DialectRegistry registry;
registry
.insert<TritonDialect, triton::gpu::TritonGPUDialect,
triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::math::MathDialect, arith::ArithDialect, scf::SCFDialect>();
context.appendDialectRegistry(registry);
@@ -121,8 +123,10 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}
llvm::LLVMContext llvmContext;
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), false /*isRocm*/);
mlir::triton::gpu::TMAMetadataTy tmaInfos;
auto llvmir = translateTritonGPUToLLVMIR(
&llvmContext, *module, SMArch.getValue(), tmaInfos, false /*isRocm*/);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}

View File

@@ -192,3 +192,4 @@ Iterators
:nosignatures:
static_range
multiple_of

View File

@@ -9,6 +9,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include <atomic>
#include <limits>
@@ -147,17 +148,17 @@ private:
BufferKind kind;
BufferId id;
size_t size;
size_t alignment;
size_t offset;
bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }
BufferT() : BufferT(BufferKind::Explicit) {}
BufferT(BufferKind kind)
: kind(kind), id(InvalidBufferId), size(0), offset(0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
BufferT() : BufferT(BufferKind::Explicit, 0) {}
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
size_t offset = 0)
: kind(kind), id(nextId++), size(size), alignment(alignment),
offset(offset) {}
};
/// Op -> Scratch Buffer

View File

@@ -21,7 +21,7 @@ namespace mlir {
/// This lattice value represents known information on the axes of a lattice.
class AxisInfo {
public:
typedef SmallVector<int64_t, 4> DimVectorT;
typedef SmallVector<int64_t> DimVectorT;
public:
/// Default constructor

View File

@@ -3,6 +3,7 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
@@ -121,7 +122,11 @@ bool isSingleValue(Value value);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
Type getElementType(Value value);
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
@@ -324,6 +329,10 @@ protected:
FuncDataMapT funcMap;
SmallVector<FunctionOpInterface> roots;
};
// Create a basic DataFlowSolver with constant and dead code analysis included.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);
} // namespace mlir

View File

@@ -19,6 +19,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::tensor::TensorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::ROCDL::ROCDLDialect",
"mlir::NVVM::NVVMDialect"];
@@ -26,6 +27,9 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">,
Option<"TmaMetadata", "tma-metadata",
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
"tma metadata to the runtime">,
Option<"isROCM", "is-rocm",
"bool", /*default*/"false",
"compile for ROCM-compatible LLVM">,

View File

@@ -3,6 +3,8 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Target/PTX/TmaMetadata.h"
#include <memory>
namespace mlir {
@@ -12,9 +14,10 @@ template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = false);
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass(
int computeCapability = 80,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr,
bool isROCM = false);
} // namespace triton

View File

@@ -2,6 +2,7 @@
#define TRITON_CONVERSION_PASSES_H
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Target/PTX/TmaMetadata.h"
namespace mlir {
namespace triton {

View File

@@ -25,6 +25,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
Option<"threadsPerWarp", "threads-per-warp",
"int32_t", /*default*/"32",
"number of threads per warp">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"number of ctas in a cga">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"compute capability">
];
}

View File

@@ -11,6 +11,9 @@ template <typename T> class OperationPass;
namespace triton {
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
constexpr static char AttrComputeCapabilityName[] =
"triton_gpu.compute-capability";
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
@@ -19,7 +22,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32);
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32,
int numCTAs = 1, int computeCapability = 80);
} // namespace triton
} // namespace mlir

View File

@@ -1,2 +1,4 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)
add_subdirectory(TritonNvidiaGPU)
add_subdirectory(NVGPU)

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
#add_subdirectory(Transforms)

View File

@@ -0,0 +1,14 @@
set(LLVM_TARGET_DEFINITIONS NVGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu)
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(NVGPUTableGen)
set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td)
mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(NVGPUAttrDefsIncGen)

View File

@@ -0,0 +1,47 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_NVGPU_IR_DIALECT_H_
#define TRITON_DIALECT_NVGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "triton/Dialect/NVGPU/IR/Dialect.h.inc"
#include "triton/Dialect/NVGPU/IR/OpsEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/NVGPU/IR/Ops.h.inc"
namespace mlir {
namespace triton {
namespace nvgpu {} // namespace nvgpu
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_NVGPU_IR_DIALECT_H_

View File

@@ -0,0 +1,33 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef NVGPU_ATTRDEFS
#define NVGPU_ATTRDEFS
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"
class NVGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<NVGPU_Dialect, name, traits, baseCppClass> {
}
#endif

View File

@@ -0,0 +1,40 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef NVGPU_DIALECT
#define NVGPU_DIALECT
include "mlir/IR/OpBase.td"
def NVGPU_Dialect : Dialect {
let name = "nvgpu";
let cppNamespace = "::mlir::triton::nvgpu";
let description = [{
NVGPU Dialect.
}];
let dependentDialects = [
"mlir::LLVM::LLVMDialect"
];
}
#endif

View File

@@ -0,0 +1,380 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef NVGPU_OPS
#define NVGPU_OPS
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
def I64Ptr_shared : LLVM_IntPtrBase<64, 3>;
class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> {
string llvmBuilder = [{
createWGMMAFence(builder);
}];
let assemblyFormat = "attr-dict";
}
def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
let assemblyFormat = "attr-dict";
string llvmBuilder = [{
createWGMMACommitGroup(builder);
}];
}
def NVGPU_WGMMAWaitOp : NVGPU_Op<"wgmma_wait_group", []> {
let arguments = (ins I32Attr:$pendings);
let assemblyFormat = "attr-dict";
string llvmBuilder = [{
createWGMMAWaitGroup(builder, $pendings);
}];
}
def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count);
let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)";
string llvmBuilder = [{
auto *i32Ty = builder.getInt32Ty();
auto *arriveCnt = builder.getInt32($count);
createExternalCall(builder, "__nv_mbarrier_init", {builder.CreatePtrToInt($mbarrier, i32Ty),
arriveCnt,
builder.CreateIntCast($pred, i32Ty, false)});
}];
}
def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType",
"mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'",
[
I32EnumAttrCase<"normal", 0>,
I32EnumAttrCase<"cp_async", 1>,
I32EnumAttrCase<"expect_tx", 2>,
I32EnumAttrCase<"remote", 3>,
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional<I32>:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr<I32Attr, "0">:$txCount);
let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)";
string llvmBuilder = [{
auto *i32Ty = builder.getInt32Ty();
createMBarrierArrive(builder, $arriveType, builder.CreatePtrToInt($mbarrier, i32Ty),
builder.CreateIntCast($pred, i32Ty, false), $ctaId,
$txCount);
}];
}
def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase);
let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)";
string llvmBuilder = [{
auto *i32Ty = builder.getInt32Ty();
createExternalCall(builder, "__nv_mbarrier_wait", {builder.CreatePtrToInt($mbarrier, i32Ty),
builder.CreateIntCast($phase, i32Ty, false)});
}];
}
def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> {
let arguments = (ins I32:$bar, I32:$numThreads);
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
string llvmBuilder = [{
createExternalCall(builder, "__nv_bar_arrive", {$bar, $numThreads});
}];
}
def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> {
let arguments = (ins I32:$bar, I32:$numThreads);
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
string llvmBuilder = [{
createExternalCall(builder, "__nv_bar_wait", {$bar, $numThreads});
}];
}
def WGMMADesc_ModeAttr : I32EnumAttr<"WGMMADescMode",
"wgmma desc mode, either 'none', 'swizzle128', 'swizzle64', or 'swizzle32'",
[
I32EnumAttrCase<"none", 0>,
I32EnumAttrCase<"swizzle128", 1>,
I32EnumAttrCase<"swizzle64", 2>,
I32EnumAttrCase<"swizzle32", 3>
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> {
let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode);
let results = (outs I64:$res);
let assemblyFormat = "$buffer `,` $height attr-dict `:` functional-type(operands, results)";
string llvmBuilder = [{
$res = createWGMMADesc(builder, builder.CreatePtrToInt($buffer, builder.getInt32Ty()), $mode, $height);
}];
}
def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> {
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc,
I1:$pred, Variadic<I32>:$coords, Optional<I16>:$mcastMask);
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
auto *i32Ty = builder.getInt32Ty();
auto *i64Ty = builder.getInt64Ty();
createTMALoadTiled(builder,
builder.CreatePtrToInt($dst, i32Ty),
builder.CreatePtrToInt($mbarrier, i32Ty),
builder.CreatePtrToInt($tmaDesc, i64Ty),
$l2Desc, $mcastMask, builder.CreateIntCast($pred, builder.getInt32Ty(), false), $coords);
}];
}
def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> {
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic<I32>:$coords, I16Attr:$mcastMask);
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
auto *i32Ty = builder.getInt32Ty();
auto *i64Ty = builder.getInt64Ty();
createTMALoadIm2col(builder,
builder.CreatePtrToInt($dst, i32Ty),
builder.CreatePtrToInt($mbarrier, i32Ty),
builder.CreatePtrToInt($tmaDesc, i64Ty),
$l2Desc, $mcastMask, $im2colOffsets, builder.CreateIntCast($pred, builder.getInt32Ty(), false), $coords);
}];
}
def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout",
"wgmma layout, either 'row' or 'col'",
[
I32EnumAttrCase<"row", 0>,
I32EnumAttrCase<"col", 1>
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType",
"wgmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'",
[
I32EnumAttrCase<"s8", 0>,
I32EnumAttrCase<"s32", 1>,
I32EnumAttrCase<"e4m3", 2>,
I32EnumAttrCase<"e5m2", 3>,
I32EnumAttrCase<"f16", 4>,
I32EnumAttrCase<"bf16", 5>,
I32EnumAttrCase<"tf32", 6>,
I32EnumAttrCase<"f32", 7>
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">;
def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
I32Attr:$m, I32Attr:$n, I32Attr:$k,
WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
let results = (outs LLVM_AnyStruct:$res);
let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
string llvmBuilder = [{
$res = createWGMMA(builder, $m, $n, $k, $eltTypeC, $eltTypeA, $eltTypeB, $layoutA, $layoutB, $opA, $opB, $opC);
}];
}
def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> {
let assemblyFormat = "attr-dict";
string llvmBuilder = [{
createExternalCall(builder, "__nv_cga_barrier_sync");
}];
}
def NVGPU_CGABarrierArriveOp : NVGPU_Op<"cga_barrier_arrive", []> {
let assemblyFormat = "attr-dict";
string llvmBuilder = [{
createExternalCall(builder, "__nv_cga_barrier_arrive");
}];
}
def NVGPU_CGABarrierWaitOp : NVGPU_Op<"cga_barrier_wait", []> {
let assemblyFormat = "attr-dict";
string llvmBuilder = [{
createExternalCall(builder, "__nv_cga_barrier_wait");
}];
}
def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> {
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec);
let builders = [
OpBuilder<(ins "Type":$resultTy, "Value":$addr, "Value":$ctaId)>,
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth, "unsigned":$vec)>,
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth)>
];
let results = (outs LLVM_LoadableType:$result);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
string llvmBuilder = [{
$result = createLoadSharedCluster(builder, $addr, $ctaId, $bitwidth, $vec);
}];
}
def NVGPU_StoreDSmemOp : NVGPU_Op<"store_dsmem", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId,
Variadic<LLVM_LoadableType>:$values, I1:$pred);
let builders = [
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "Value":$value, "Value":$pred)>,
];
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
createStoreSharedCluster(builder, $addr, $ctaId, $values, $pred, op.getBitwidth(), op.getVec());
}];
let extraClassDeclaration = [{
unsigned getBitwidth();
unsigned getVec();
}];
}
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
let arguments = (ins BoolAttr:$bCluster);
string llvmBuilder = [{
if ($bCluster)
createExternalCall(builder, "__nv_fence_async_shared_cluster", {});
else
createExternalCall(builder, "__nv_fence_async_shared_cta", {});
}];
let assemblyFormat = "attr-dict";
}
def NVGPU_FenceMBarrierInitOp : NVGPU_Op<"fence_mbarrier_init", []> {
string llvmBuilder = [{
createExternalCall(builder, "__nv_fence_mbarrier_init", {});
}];
let assemblyFormat = "attr-dict";
}
def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> {
let arguments = (ins I1Attr:$relaxed);
string llvmBuilder = [{
if ($relaxed)
createExternalCall(builder, "__nv_cluster_arrive_relaxed", {});
else
createExternalCall(builder, "__nv_cluster_arrive", {});
}];
let assemblyFormat = "attr-dict";
}
def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
string llvmBuilder = [{
createExternalCall(builder, "__nv_cluster_wait", {});
}];
let assemblyFormat = "attr-dict";
}
def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic<I32>:$coords);
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
auto *i32Ty = builder.getInt32Ty();
auto *i64Ty = builder.getInt64Ty();
createTMAStoreTiled(builder,
builder.CreatePtrToInt($tmaDesc, i64Ty),
builder.CreatePtrToInt($src, i32Ty),
builder.CreateIntCast($pred, i32Ty, false), $coords);
}];
}
def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I8Ptr_shared:$addr, Variadic<I32>:$datas);
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
auto *i32Ty = builder.getInt32Ty();
createStoreMatrix(builder,
builder.CreatePtrToInt($addr, i32Ty),
$datas);
}];
}
def NVGPU_OffsetOfStmatrixV4Op : NVGPU_Op<"offset_of_stmatrix_v4", []> {
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
let results = (outs I32:$offset);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
string llvmBuilder = [{
$offset = createOffsetOfStmatrixV4(builder, $threadId, $rowOfWarp, $elemIdx, $leadingDimOffset, $rowStride, $swizzleEnabled);
}];
}
def NVGPU_OffsetOfSts64Op : NVGPU_Op<"offset_of_sts64", []> {
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
let results = (outs I32:$offset);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
string llvmBuilder = [{
$offset = createOffsetOfSts64(builder, $threadId, $rowOfWarp, $elemIdx, $leadingDimOffset, $rowStride, $swizzleEnabled);
}];
}
def NVGPU_Sts64Op : NVGPU_Op<"sts64", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I32:$offset, AnyTypeOf<[F32, I32]>:$d0, AnyTypeOf<[F32, I32]>:$d1);
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
createSts64(builder, $offset, $d0, $d1);
}];
}
def NVGPU_CvtPackOp : NVGPU_Op<"cvt_pack", []> {
let arguments = (ins AnyTypeOf<[F16, I16]>:$d0, AnyTypeOf<[F16, I16]>:$d1);
let results = (outs I32:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
string llvmBuilder = [{
$result = createCvtPack(builder, $d0, $d1);
}];
}
def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
let results = (outs I32:$result);
let assemblyFormat = "attr-dict";
string llvmBuilder = [{
$result = createClusterId(builder);
}];
}
def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> {
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
createRegAlloc(builder, $regCount);
}];
}
def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> {
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "operands attr-dict `:` type(operands)";
string llvmBuilder = [{
createRegDealloc(builder, $regCount);
}];
}
#endif

View File

@@ -0,0 +1,41 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_NVGPU_NVGPUTOLLVMIRTRANSLATION_H
#define TRITON_DIALECT_NVGPU_NVGPUTOLLVMIRTRANSLATION_H
namespace mlir {
class DialectRegistry;
class MLIRContext;
/// Register the nvgpu dialect and the translation from it to the LLVM IR in the
/// given registry;
void registerNVGPUDialectTranslation(DialectRegistry &registry);
/// Register the nvgpu dialect and the translation from it in the registry
/// associated with the given context.
void registerNVGPUDialectTranslation(MLIRContext &context);
} // namespace mlir
#endif // TRITON_DIALECT_NVGPU_NVGPUTOLLVMIRTRANSLATION_H

View File

@@ -9,6 +9,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"

View File

@@ -9,6 +9,8 @@ include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -563,6 +565,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
let results = (outs TT_TensorPtr:$result);
// TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly
// Add additional `[]` to increase readability and split variadic lists
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)";

View File

@@ -74,7 +74,7 @@ def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
// Scalar Pointer Type: `ptr<>`
def TT_Ptr : TT_PtrOf<[AnyType]>;
// Tensor of Pointer Type
// Tensor of Pointer Type: `tensor<ptr<>>`
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
// Tensor of Pointer Type or Pointer type: `tensor<ptr<>>` or `ptr<>`

View File

@@ -14,6 +14,8 @@ namespace triton {
bool isTensorPointerType(Type type);
bool isTensorOrTensorPointerType(Type type);
unsigned getPointeeBitWidth(Type type);
Type getPointeeType(Type type);

View File

@@ -9,7 +9,6 @@ namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
std::unique_ptr<Pass> createReorderBroadcastPass();
std::unique_ptr<Pass>
createRewriteTensorPointerPass(int computeCapability = 80);

View File

@@ -3,9 +3,13 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu)
add_public_tablegen_target(TritonGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

View File

@@ -7,6 +7,7 @@
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
@@ -71,17 +72,41 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
SmallVector<unsigned>
getShapePerCTA(Attribute layout,
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
SmallVector<unsigned> getOrder(Attribute layout);
CTALayoutAttr getCTALayout(Attribute layout);
SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
SmallVector<unsigned> getCTASplitNum(Attribute layout);
SmallVector<unsigned> getCTAOrder(Attribute layout);
/* The difference between ShapePerCTATile and ShapePerCTA:
* (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp *
* WarpsPerCTA in each dimension and is independent from the tensor shape.
* (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension.
* (3) In the implementation of emitIndices, ShapePerCTATile will
* be replicated or wrapped to fit ShapePerCTA.
*/
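
To make the distinction concrete, here is a minimal standalone sketch (illustrative only, not code from this commit) that computes both shapes for a hypothetical 32x32 tensor with sizePerThread = {2, 2}, threadsPerWarp = {8, 4}, warpsPerCTA = {1, 2} and CTASplitNum = {1, 2}:

```
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  std::array<unsigned, 2> sizePerThread{2, 2};
  std::array<unsigned, 2> threadsPerWarp{8, 4};
  std::array<unsigned, 2> warpsPerCTA{1, 2};
  std::array<unsigned, 2> CTASplitNum{1, 2};
  std::array<int64_t, 2> shape{32, 32};
  for (int d = 0; d < 2; ++d) {
    // (1) ShapePerCTATile: SizePerThread * ThreadsPerWarp * WarpsPerCTA,
    //     independent of the tensor shape.
    unsigned tile = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d];
    // (2) ShapePerCTA: shape / CTASplitNum, the slice owned by one CTA.
    int64_t perCTA = shape[d] / CTASplitNum[d];
    std::printf("dim %d: ShapePerCTATile = %u, ShapePerCTA = %lld\n", d, tile,
                (long long)perCTA);
  }
  // Output: dim 0 -> 16 vs 32, dim 1 -> 16 vs 16. Per (3) above, emitIndices
  // replicates/wraps the 16x16 tile to cover the 32x16 per-CTA shape.
  return 0;
}
```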
SmallVector<unsigned>
getShapePerCTATile(Attribute layout,
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Type type);
unsigned getNumWarpsPerCTA(Attribute layout);
unsigned getNumCTAs(Attribute layout);
bool isaDistributedLayout(Attribute layout);
bool isSharedEncoding(Value value);
bool isExpensiveCat(CatOp cat, Attribute &targetEncoding);
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
} // namespace gpu
} // namespace triton

View File

@@ -41,6 +41,19 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
}];
}
//===----------------------------------------------------------------------===//
// CTA Layout
//===----------------------------------------------------------------------===//
def CTALayoutAttr : TritonGPU_Attr<"CTALayout"> {
let parameters = (
ins
ArrayRefParameter<"unsigned">:$CTAsPerCGA,
ArrayRefParameter<"unsigned">:$CTASplitNum,
ArrayRefParameter<"unsigned">:$CTAOrder
);
}
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
@@ -64,26 +77,49 @@ are stored contiguously
_ _ _ _ /\_ _ _ _
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
For MMAv3 (e.g. Hopper GMMA), hasLeadingOffset should be true. In this case,
when the matrix is stored in shared memory, there is an offset not only in
the stride dimension but also in the leading dimension. For example, consider
a 16x128 matrix of data type I8 stored in shared memory with 64B-swizzle
mode: the offset of the element with index (0, 64) will be 16*64, compared to
1*64 when hasLeadingOffset is false.
}];
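
The arithmetic above can be reproduced with a small standalone sketch (illustrative only, not code from this commit and not the actual SharedEncodingAttr logic); it looks at element (0, 64) of the 16x128 i8 / 64B-swizzle example and ignores intra-block swizzling:

```
#include <cstdio>

int main() {
  const int rows = 16;               // size of the strided dimension
  const int blockCols = 64;          // 64B swizzle block / 1-byte i8 elements
  const int col = 64;                // element (0, 64) from the example above
  const int block = col / blockCols; // which 64-column block: 1
  // hasLeadingOffset = true: each rows x blockCols tile is stored
  // contiguously, so element (0, 64) starts at the second tile.
  const int offsetWithLeading = block * rows * blockCols; // 16 * 64 = 1024
  // hasLeadingOffset = false: only the leading-dimension offset applies.
  const int offsetWithout = block * blockCols;            // 1 * 64 = 64
  std::printf("with leading offset: %d, without: %d\n", offsetWithLeading,
              offsetWithout);
  return 0;
}
```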
// swizzle info: vec, perPhase, maxPhase
// order: the fastest-changing axis first
let parameters = (
ins
// swizzle info
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
"unsigned":$vec,
"unsigned":$perPhase,
"unsigned":$maxPhase,
ArrayRefParameter<"unsigned">:$order,
"CTALayoutAttr":$CTALayout,
"bool":$hasLeadingOffset
);
let builders = [
AttrBuilder<(ins "unsigned":$vec,
"unsigned":$perPhase,
"unsigned":$maxPhase,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout), [{
bool hasLeadingOffset = false; // default value
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset);
}]>,
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);
return get(context, 1, 1, 1, order, CTALayout);
int opIdx = dotOpEnc.getOpIdx();
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
// number of rows per phase
@@ -92,34 +128,34 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// ---- begin Volta ----
if (mmaEnc.isVolta()) {
int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8));
int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8));
perPhase = std::max<int>(perPhase, 1);
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) :
is_row && (shapePerCTA[order[0]] <= 16);
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getMMAv2kWidth());
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getMMAv2kWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getMMAv2kWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if ((32 / typeWidthInBit != dotOpEnc.getMMAv2kWidth()) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
return get(context, 1, 1, 1, order, CTALayout);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
// --- handle B operand ---
@@ -127,12 +163,19 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
llvm_unreachable("invalid operand index");
}
// ---- begin version 3 ----
if (mmaEnc.isHopper()) {
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
" is Hopper has not been implemented yet");
return $_get(context, 1, 1, 1, order, CTALayout, true);
}
// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>,
@@ -140,9 +183,38 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy), [{
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
return get(context, dotOpEnc, shape, order, bitwidth);
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
}]>,
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy), [{
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1;
// get proper shared memory swizzling mode from the contiguous dimension
// size of the origin blocked layout.
auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
perPhase = 1;
maxPhase = 8;
} else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
perPhase = 2;
maxPhase = 4;
} else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
perPhase = 4;
maxPhase = 2;
} else {
llvm_unreachable("unsupported shared memory layout for MMAv3");
}
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true);
}]>
];
@@ -194,7 +266,7 @@ used to promote memory coalescing in LoadInst and StoreInst.
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows:
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
@@ -210,82 +282,143 @@ for
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
CTAsPerCGA = {1, 1}
}>
Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows:
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for
#triton_gpu.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
CTAsPerCGA = {1, 1}
}>
Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and
4 CTAs (taking 2x2 for example) as follows:
CTA [0,0] CTA [0,1]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
CTA [1,0] CTA [1,1]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for
#triton_gpu.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
CTAsPerCGA = {2, 2}
}>
}];
let builders = [
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// TODO: compiles on MacOS but not linux?
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
// "ArrayRef<unsigned>":$threadsPerWarp,
// "ArrayRef<unsigned>":$warpsPerCTA,
// "ArrayRef<unsigned>":$order), [{
// int rank = threadsPerWarp.size();
// SmallVector<unsigned, 4> sizePerWarp(rank);
// SmallVector<unsigned, 4> sizePerCTA(rank);
// for (unsigned i = 0; i < rank; i++) {
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
// }
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
// }]>,
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// Default builder takes sizePerThread, order and numWarps, and tries to
// pack numWarps*32 threads in the provided order for use in a type
// of the given shape.
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps,
"unsigned":$threadsPerWarp), [{
int rank = sizePerThread.size();
unsigned remainingLanes = threadsPerWarp;
unsigned remainingThreads = numWarps*threadsPerWarp;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
SmallVector<unsigned, 4> rankedThreadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank - 1; ++_dim) {
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
rankedThreadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / rankedThreadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= rankedThreadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= rankedThreadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
rankedThreadsPerWarp[order[rank-1]] = threadsPerWarp / prevLanes;
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, rankedThreadsPerWarp, warpsPerCTA, order);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
ArrayRefParameter<"unsigned">:$threadsPerWarp,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
// fastest-changing axis first
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
// These attributes can be inferred from the rest
// ArrayRefParameter<"unsigned">:$sizePerWarp,
// ArrayRefParameter<"unsigned">:$sizePerCTA
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
"CTALayoutAttr":$CTALayout
);
let builders = [
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps,
"unsigned":$numThreadsPerWarp,
"CTALayoutAttr":$CTALayout), [{
unsigned rank = sizePerThread.size();
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
SmallVector<int64_t> shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
unsigned remainingLanes = numThreadsPerWarp;
unsigned remainingThreads = numWarps * numThreadsPerWarp;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
// starting from the contiguous dimension
for (unsigned d = 0; d < rank - 1; ++d) {
unsigned i = order[d];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= threadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes;
warpsPerCTA[order[rank - 1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout);
}]>,
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps,
"unsigned":$numThreadsPerWarp,
"unsigned":$numCTAs), [{
unsigned rank = sizePerThread.size();
SmallVector<unsigned, 4> CTAsPerCGA(rank);
SmallVector<unsigned, 4> CTASplitNum(rank);
ArrayRef<unsigned> CTAOrder = order;
unsigned remainingCTAs = numCTAs;
// starting from the most strided dimension
for (int d = rank - 1; d >= 0; --d) {
unsigned i = order[d];
CTAsPerCGA[i] = std::clamp<unsigned>(remainingCTAs, 1, shape[i] / sizePerThread[i]);
CTASplitNum[i] = CTAsPerCGA[i];
remainingCTAs /= CTAsPerCGA[i];
}
CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level
CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let hasCustomAssemblyFormat = 1;
}
@@ -381,13 +514,17 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
ins
"unsigned":$versionMajor,
"unsigned":$versionMinor,
ArrayRefParameter<"unsigned">:$warpsPerCTA
ArrayRefParameter<"unsigned">:$warpsPerCTA,
"CTALayoutAttr":$CTALayout,
ArrayRefParameter<"unsigned">:$instrShape
);
let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"int":$numWarps,
"CTALayoutAttr":$CTALayout,
"ArrayRef<unsigned>":$instrShape,
"ArrayRef<int64_t>":$shapeC,
"bool":$isARow,
"bool":$isBRow,
@@ -401,7 +538,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
// TODO: Share code with
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
// rep,spw and fpw.
@@ -426,12 +562,14 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shapeC[1] / spw[1]);
} while (wpt_nm1 != wpt);
return $_get(context, versionMajor, versionMinor, wpt);
return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape);
}]>,
AttrBuilder<(ins "int":$versionMajor,
"int":$numWarps,
"CTALayoutAttr":$CTALayout,
"ArrayRef<unsigned>":$instrShape,
"ArrayRef<int64_t>":$shapeA,
"ArrayRef<int64_t>":$shapeB,
"ArrayRef<int64_t>":$shapeC,
@@ -441,15 +579,20 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
return get(context, versionMajor, numWarps, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
bool isVolta() const;
bool isAmpere() const;
bool isHopper() const;
unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;
// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;
// Number of bits in versionMinor to hold the ID of the MMA encoding instance.
// Here 5 bits can hold 32 IDs in a single module.
static constexpr int numBitsToHoldMmaV1ID{5};
@@ -544,6 +687,4 @@ section 9.7.13.4.1 for more details.
}];
}
#endif

View File

@@ -16,6 +16,7 @@ def TritonGPU_Dialect : Dialect {
let dependentDialects = [
"triton::TritonDialect",
"mlir::triton::nvgpu::NVGPUDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];
@@ -23,14 +24,27 @@ def TritonGPU_Dialect : Dialect {
let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static int getNumWarps(ModuleOp mod) {
Attribute numWarps = mod->getDiscardableAttr("triton_gpu.num-warps");
if(!numWarps)
if(!mod->hasAttr("triton_gpu.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return numWarps.cast<IntegerAttr>().getInt();
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
}
static int getNumCTAs(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-ctas"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
}
static int getComputeCapability(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.compute-capability"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
}
void registerTypes();
static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
static int getThreadsPerWarp(ModuleOp mod) {
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
if(!threadsPerWarp) {
@@ -38,7 +52,6 @@ def TritonGPU_Dialect : Dialect {
}
return threadsPerWarp.cast<IntegerAttr>().getInt();
}
}];
let useDefaultAttributePrinterParser = 1;

View File

@@ -2,6 +2,7 @@
#define TRITONGPU_OPS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
@@ -46,6 +47,20 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
}];
}
def TTG_AsyncBulkWaitOp : TTG_Op<"async_bulk_wait"> {
let summary = "async bulk wait";
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 90;
}
}];
}
def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
let summary = "async commit group";
@@ -58,6 +73,18 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
}];
}
def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> {
let summary = "async bulk commit group";
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 90;
}
}];
}
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
@@ -106,6 +133,98 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
let results = (outs TT_Type:$result);
}
// TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified
def TTG_InsertSliceOp : TTG_Op<"insert_slice",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
MemoryEffects<[MemRead, MemWrite]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
TypesMatchWith<"infer other type from src type",
"src", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
let summary = "insert slice";
let description = [{
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation's
`$index` argument and `$axis` attribute.
It returns a copy of `$dst` with the proper slice updated with the value of `$src`.
When converting from `tt.load` to `triton_gpu.insert_slice`, the `$evict`, `$cache`, and `$isVolatile` fields
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
The insert_slice operation supports the following arguments:
* src: the tensor that is inserted.
* dst: the tensor into which the `$src` tensor is inserted.
* index: the index along the given `$axis` at which the `$src` tensor is inserted into the `$dst` tensor.
* mask: an optional boolean tensor which specifies which
elements of the `$src` tensor are inserted into the `$dst` tensor.
* other: an optional tensor which specifies what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.
ttgpu.load_tile_async is deprecated.
triton_gpu.insert_slice might be further lowered into an async variant for different hardware implementations.
Like tt.load, ttgpu.insert_slice/insert_slice_async has two modes depending on the type of src:
mode 1: ptr/src is a tensor of pointers
mode 2: ptr/src is a tensor pointer
Some typical lowering paths are:
In case the load is pipelined by the pipeline pass (i.e. the load is inside the kBlock loop):
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1)
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1) -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 1) + ttgpu.await -> llvm
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -(Pipeline)-> ttgpu.insert_slice(mode 2) -(MaterializeLoad)-> ttgpu.insert_slice_async_v2(mode 2) + ttgpu.await -> llvm
Otherwise:
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1)
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 1) + ttgpu.await -> llvm
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 2) + ttgpu.await -> llvm
Example:
```
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
%2 = triton_gpu.insert_slice %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
```
}];
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);
let builders = [
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"Value":$mask, "Value":$other,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
];
let results = (outs TT_Tensor:$result);
let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
}];
let hasCustomAssemblyFormat = 1;
}
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
@@ -173,7 +292,8 @@ def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
MemoryEffects<[MemRead]>,
// TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
MemoryEffects<[MemRead, MemWrite]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
@@ -219,7 +339,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
```
}];
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);

View File

@@ -0,0 +1,26 @@
#ifndef TRITONGPU_TYPES
#define TRITONGPU_TYPES
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"
class TTG_TypeDef<string name, string _mnemonic, list<Trait> traits = []>
: TypeDef<TritonGPU_Dialect, name, traits> {
let mnemonic = _mnemonic;
}
def TTG_TokenType : TTG_TypeDef<"Token", "token"> {
let parameters = (ins "int32_t":$type);
let builders = [
TypeBuilder<(ins "unsigned":$type), [{
return $_get($_ctxt, type);
}]>
];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
#endif

View File

@@ -0,0 +1,10 @@
#ifndef TRITONGPU_IR_TYPES_H_
#define TRITONGPU_IR_TYPES_H_
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -2,9 +2,14 @@
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
int numWarps = 4,
int numCTAs = 1,
int computeCapability = 80);
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
@@ -25,6 +30,8 @@ std::unique_ptr<Pass> createTritonGPUVerifier();
std::unique_ptr<Pass> createTritonGPUOptimizeDotOperandsPass();
std::unique_ptr<Pass> createTritonGPUOptimizeEpiloguePass();
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

View File

@@ -14,13 +14,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let constructor = "mlir::createTritonGPUPipelinePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"2",
"number of pipeline stages">
"int32_t", /*default*/"3",
"number of pipeline stages">,
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps per block">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"number of CTAs per CGA">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
@@ -50,6 +60,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
@@ -70,6 +81,7 @@ def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir
let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::triton::TritonDialect"];
}
@@ -96,6 +108,20 @@ def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::ModuleOp"> {
let summary = "Optimize epilogue: (1) Store accumulators directly without going thorough SMEM in epilogue.";
let description = [{
}];
let constructor = "mlir::createTritonGPUOptimizeEpiloguePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {

View File

@@ -13,15 +13,17 @@ namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps,
int threadsPerWarp);
TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp,
int numCTAs);
int getNumWarps() const { return numWarps; }
int getThreadsPerWarp() const { return threadsPerWarp; }
int getNumCTAs() const { return numCTAs; }
private:
MLIRContext *context;
int numWarps;
int threadsPerWarp;
int numCTAs;
};
class TritonGPUConversionTarget : public ConversionTarget {

View File

@@ -10,8 +10,99 @@
namespace mlir {
namespace triton {
class LoadOp;
class StoreOp;
class FuncOp;
namespace gpu {
class SharedEncodingAttr;
}
} // namespace triton
LogicalResult fixupLoops(ModuleOp mod);
SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
const ArrayRef<int64_t> &shape,
RankedTensorType type);
/// Returns true if the Load is for TMA
bool isLoadFromTensorPtr(triton::LoadOp op);
/// Returns true if the store is for TMA
bool isStoreToTensorPtr(triton::StoreOp op);
/// Return the first consumer of v
Operation *getFirstUser(Value v);
/// Return the proper SharedEncodingAttr according to shape/order
triton::gpu::SharedEncodingAttr getSharedEncoding(RankedTensorType tensorTy);
/* Dump Triton IR in graphviz dot format.
*
* You can override `onValue` and `onOperation` in a subclass to mark
* specific Values and Operations. The below subclass
* GraphLayoutMarker is an example.
*
* Default NodeInfo for Value nodes:
* {{"shape": "box"},
* {"style", "filled"},
* {"fillcolor", "white"},
* {"label", shapeStr}}
*
* Default NodeInfo for Operation nodes:
* {{"shape": "ellipse"},
* {"style", "filled"},
* {"fillcolor", "white"},
* {"label", operationName}}
*
* If the key "label" is not set by `onValue` or `onOperation`, default labels
* will be generated. For Value node, the default label is the shape string and
* for Operation node, it is the operation name.
*
* Reference:
* https://graphviz.org/doc/info/shapes.html
* https://graphviz.org/doc/info/colors.html
*
* Usage:
* C++: GraphDumper().dumpToFile(func, "func.dot");
* Shell: dot -Tjpg func.dot -o func.jpg
*/
class GraphDumper {
public:
using NodeInfo = std::map<std::string, std::string>;
// Override this function to mark specific Values
virtual NodeInfo onValue(Value value) const;
// Override this function to mark specific Operations
virtual NodeInfo onOperation(Operation *op) const;
std::string dump(triton::FuncOp func) const;
void dumpToFile(triton::FuncOp func, const std::string &filename) const;
protected:
std::string getShapeStr(const Type &type) const;
std::string getUniqueId(Value value) const;
std::string getUniqueId(Operation *op) const;
std::string emitNode(const std::string &id, const NodeInfo style) const;
std::string emitEdge(const std::string &srcId,
const std::string &destId) const;
std::string emitValueNode(Value value) const;
std::string emitOperationNode(Operation *op) const;
};
/* A subclass of GraphDumper that marks different layout kinds in different
* colors.*/
class GraphLayoutMarker : public GraphDumper {
public:
NodeInfo onValue(Value value) const override;
protected:
std::string getColor(const Type &type) const;
};
// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret);
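A hedged sketch of customizing the GraphDumper above, in the spirit of GraphLayoutMarker: a hypothetical subclass that colors `tt.dot` operations in the emitted graphviz (the op check and colors are illustrative only):
```
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

class DotMarker : public mlir::GraphDumper {
public:
  NodeInfo onOperation(mlir::Operation *op) const override {
    if (llvm::isa<mlir::triton::DotOp>(op))
      return {{"shape", "ellipse"},
              {"style", "filled"},
              {"fillcolor", "lightblue"},
              {"label", "tt.dot"}};
    return GraphDumper::onOperation(op); // fall back to the default styling
  }
};
// Usage: DotMarker().dumpToFile(func, "func.dot");  then: dot -Tjpg func.dot -o func.jpg
```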
@@ -38,6 +129,23 @@ void rematerializeConversionChain(
LogicalResult canMoveOutOfLoop(BlockArgument arg,
SmallVector<Operation *> &cvts);
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
SmallVector<Value> delinearize(OpBuilder &b, Location loc, unsigned linear,
ArrayRef<unsigned> shape);
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
ArrayRef<unsigned> shape);
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
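For reference, a scalar sketch of the arithmetic that the delinearize/linearize helpers above emit as MLIR ops, assuming order[0] names the fastest-varying dimension:
```
#include <array>
#include <cassert>
#include <cstdio>

// Plain-integer model of delinearize/linearize for a rank-2 shape.
static std::array<unsigned, 2> delinearizeScalar(unsigned linear,
                                                 std::array<unsigned, 2> shape,
                                                 std::array<unsigned, 2> order) {
  std::array<unsigned, 2> multi{};
  for (unsigned d : order) {      // fastest-varying dimension first
    multi[d] = linear % shape[d];
    linear /= shape[d];
  }
  return multi;
}

static unsigned linearizeScalar(std::array<unsigned, 2> multi,
                                std::array<unsigned, 2> shape,
                                std::array<unsigned, 2> order) {
  unsigned linear = 0;
  for (int i = 1; i >= 0; --i) {  // slowest-varying dimension first
    unsigned d = order[i];
    linear = linear * shape[d] + multi[d];
  }
  return linear;
}

int main() {
  // shape = [4, 8], order = {1, 0}: dimension 1 is contiguous.
  auto multi = delinearizeScalar(13, {4, 8}, {1, 0});   // -> {1, 5}
  assert(linearizeScalar(multi, {4, 8}, {1, 0}) == 13); // round-trips
  std::printf("%u %u\n", multi[0], multi[1]);
  return 0;
}
```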

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,15 @@
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu)
add_public_tablegen_target(TritonNvidiaGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td)
mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen)

View File

@@ -0,0 +1,46 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TritonNvidiaGPU depends on Triton
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonNvidiaGPU/IR/Traits.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_

View File

@@ -0,0 +1,53 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_NVIDIA_GPU_IR_TRAITS_H_
#define TRITON_NVIDIA_GPU_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifySource1IsSharedEncoding(Operation *op);
} // namespace impl
template <typename ConcreteType>
class Source1IsSharedEncoding
: public TraitBase<ConcreteType, Source1IsSharedEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySource1IsSharedEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_ATTRDEFS
#define TRITONNVIDIAGPU_ATTRDEFS
include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
#endif

View File

@@ -0,0 +1,82 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_DIALECT
#define TRITONNVIDIAGPU_DIALECT
include "mlir/IR/OpBase.td"
def TritonNvidiaGPU_Dialect : Dialect {
let name = "triton_nvidia_gpu";
let cppNamespace = "::mlir::triton::nvidia_gpu";
let hasOperationAttrVerify = 1;
let description = [{
Triton Nvidia GPU Dialect.
}];
let dependentDialects = [
"triton::TritonDialect",
"triton::gpu::TritonGPUDialect",
"mlir::triton::nvgpu::NVGPUDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];
let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
}
static int getNumCTAs(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-ctas"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
}
static int getComputeCapability(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.compute-capability"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
}
void registerTypes();
// Warp specialization related:
static std::string getWSSupportedAttrName() { return "triton_gpu.enable-warp-specialization"; }
static int getWSSupportedAttr(ModuleOp mod) {
auto name = getWSSupportedAttrName();
if (!mod->hasAttr(name)) return 0;
return mod->getAttrOfType<IntegerAttr>(name).getInt();
}
}];
let useDefaultTypePrinterParser = 1;
}
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
#endif

View File

@@ -0,0 +1,386 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_OPS
#define TRITONNVIDIAGPU_OPS
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
def Source1IsSharedEncoding: NativeOpTrait<"Source1IsSharedEncoding">;
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
class TTNG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonNvidiaGPU_Dialect, mnemonic, traits>;
// --------------------------------------------------------------------------------------------------
// MBarrier related Ops:
// 1, These mbarrier commands are currently not needed, and not taken into consideration:
// (1), mbarrier.expect_tx
// (2), mbarrier.arrive_drop
// (3), mbarrier.complete_tx
// (4), mbarrier.inval
//
// 2, The mbarriers can be created as a vector and accessed separately via tensor.extract.
// The mbarriers created as a vector have their counters initialized with the same configuration. A
// typical example to demonstrate this:
//
// %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4x!tt.ptr<i64>>
// scf.for %iv = %lb to %ub step %step iter_args() -> () {
// %buffer_id = arith.remi %iv, %c4 : i32
// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
// triton_nvidia_gpu.mbarrier_arrive %2 {expectTx = 2048} : !tt.ptr<i64> -> ()
// }
// ...
// scf.for %iv = %lb to %ub step %step iter_args() -> () {
// %buffer_id = arith.remi %iv, %c4 : i32
// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
// triton_nvidia_gpu.mbarrier_wait %2, %c0 : !tt.ptr<i64>, i1 -> ()
// }
def TTNG_AllocMBarrierOp : TTNG_Op<"alloc_mbarrier", [MemoryEffects<[MemAlloc]>]> {
let summary = "allocate a vector of mbarriers";
let description = [{
Allocate and initialize a vector of mbarriers. The size of the vector is implied in the returned type.
Each mbarrier is initialized as:
1, the current phase initialized to 0.
2, the expected arrival count initialized to 'count'.
3, the pending arrival count initialized to 'count'.
4, the tx-count initialized to 0.
Example:
case a. when created in vector:
%1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4xi64>
case b. when created in scalar:
%1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : !tt.ptr<i64>
}];
let assemblyFormat = [{attr-dict `:` type($result)}];
let arguments = (ins I32Attr:$count);
let results = (outs AnyTypeOf<[TT_Ptr, I64Tensor]>:$result);
}
def TTNG_ExtractMBarrierOp : TTNG_Op<"extract_mbarrier", [Pure]> {
let summary = "extract a mbarrier from a vector of mbarriers";
let description = [{
Extract a mbarrier from a vector of mbarriers
Example:
%1 = triton_nvidia_gpu.extract_mbarrier %mbarriers[%idx] : tensor<4xi64>, index -> !tt.ptr<i64>
}];
let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor) `,` type($index) `->` type($result)";
let arguments = (ins I64Tensor:$tensor, I32:$index);
let results = (outs TT_Ptr:$result);
}
def TTNG_MBarrierWaitOp : TTNG_Op<"mbarrier_wait", [MemoryEffects<[MemRead, MemWrite]>]> {
let summary = "mbarrier wait";
let description = [{
This operation defines the wait action for an mbarrier.
The subsequent operations should not execute until this operation completes waiting.
Example:
triton_nvidia_gpu.mbarrier_wait %0, %1 : !tt.ptr<i64>
}];
let arguments = (ins TT_Ptr:$mbarrier, I1: $phase);
let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type($mbarrier)";
}
def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments,
MemoryEffects<[MemWrite]>]> {
let summary = "mbarrier arrive";
let description = [{
This operation defines the arrive action for an mbarrier.
txCount:
An optional attribute that sets the tx-count. This op will be lowered into
mbarrier.arrive.expect_tx if the attribute is present.
trackAsyncOp:
If true, this op will be lowered into cp.async.mbarrier.arrive.noinc.
pred:
Only perform the arrive action when pred is true.
remoteCtaId:
If set, perform a remote arrive action.
Example:
triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr<i64>
}];
let arguments = (ins TT_Ptr:$mbarrier,
Optional<I1>:$pred,
Optional<I32>:$remoteCtaId,
I1Attr: $trackAsyncOp,
DefaultValuedAttr<I32Attr, "0">: $txCount
);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
let arguments = (ins BoolAttr:$bCluster);
let summary = "fence proxy async";
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 90;
}
}];
}
// TODO[goostavz]: ThreadId & ClusterCTAId should not be exposed to
// ttgpu level. Remove them when async dialect is ready.
def TTNG_GetThreadIdOp : TTNG_Op<"get_thread_id", [Pure]> {
let description = [{
Returns the one-dimensional threadId.
}];
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_GetClusterCTAIdOp : TTNG_Op<"get_cluster_cta_id", [Pure]> {
let description = [{
Returns the one-dimensional cluster_cta_id.
}];
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> {
let summary = "named barrier arrive";
let arguments = (ins I32:$bar, I32: $numThreads);
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}
def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> {
let summary = "named barrier wait";
let arguments = (ins I32:$bar, I32: $numThreads);
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}
def TTNG_InsertSliceAsyncV2Op : TTNG_Op<"insert_slice_async_v2",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
// TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
MemoryEffects<[MemRead, MemWrite]>]> {
let arguments = (ins AnyTypeOf<[TT_Ptr, TT_PtrTensor]>:$src, TT_Tensor:$dst,
I32:$index, TT_Ptr:$mbar,
Optional<AnyTypeOf<[I1Tensor, I1]>>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
// TODO: the abstraction of barriers in ttgpu level is pending, will revisit later
// def TTNG_AwaitOp : TTNG_Op<"await", []> {
// let arguments = (ins TTNG_TokenType:$token);
// let assemblyFormat = "$token attr-dict `:` type($token)";
// }
def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> {
let arguments = (ins I1Attr:$relaxed);
let assemblyFormat = "attr-dict";
}
def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
let assemblyFormat = "attr-dict";
}
//
// DotAsync Op
//
def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot async";
let description = [{
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> {
let summary = "dot wait";
let description = [{
This operation defines the wait action for an async dot, e.g. MMAv3.
The subsequent operations should not execute until this operation completes waiting.
}];
let arguments = (ins I32Attr:$pendings);
let assemblyFormat = "attr-dict";
}
def TTNG_StoreAsyncOp : TTNG_Op<"store_async",
[Source1IsSharedEncoding,
MemoryEffects<[MemWrite]>]> {
let summary = "store asynchronous by a tensor pointer";
let arguments = (ins TT_TensorPtr:$dst, TT_Tensor:$src,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
def TTNG_GetAgentIdOp : TTNG_Op<"get_agent_id", [Pure]> {
let results = (outs I32:$result);
let builders = [OpBuilder<(ins)>];
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Token
//
def TTNG_CreateTokenOp : TTNG_Op<"create_token"> {
let results = (outs TensorOf<[TTNG_TokenType]>:$result);
let arguments = (ins I32Attr:$num);
let builders = [OpBuilder<(ins "uint32_t":$num)>];
let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
//
// Mutex
//
def TTNG_GetMutexRoleIdOp : TTNG_Op<"get_mutex_role_id"> {
let results = (outs I32:$result);
let arguments = (ins I32Attr:$num);
let builders = [OpBuilder<(ins "uint32_t":$num)>];
let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_CreateMutexOp : TTNG_Op<"create_mutex"> {
let results = (outs TTNG_MutexType:$result);
let builders = [OpBuilder<(ins)>];
let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_LockOp : TTNG_Op<"lock"> {
let arguments = (ins TTNG_MutexType:$mutex);
let assemblyFormat = "$mutex attr-dict `:` type(operands)";
}
def TTNG_UnlockOp : TTNG_Op<"unlock"> {
let arguments = (ins TTNG_MutexType:$mutex);
let assemblyFormat = "$mutex attr-dict `:` type(operands)";
}
def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> {
let summary = "register allocation";
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "$regCount attr-dict `:` type(operands)";
}
def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
let summary = "register deallocation";
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "$regCount attr-dict `:` type(operands)";
}
#endif

View File

@@ -0,0 +1,37 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_TYPES
#define TRITONNVIDIAGPU_TYPES
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"
class TTNG_TypeDef<string name, string _mnemonic>
: TypeDef<TritonNvidiaGPU_Dialect, name> {
let mnemonic = _mnemonic;
}
def TTNG_TokenType : TTNG_TypeDef<"Token", "token">;
def TTNG_MutexType : TTNG_TypeDef<"Mutex", "mutex">;
#endif

View File

@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITONNVIDIAGPU_IR_TYPES_H_
#define TRITONNVIDIAGPU_IR_TYPES_H_
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU)
add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen)

View File

@@ -0,0 +1,81 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
namespace nvidia_gpu {
// Used by Triton runtime
struct ClusterInfo {
ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {}
int clusterDimX;
int clusterDimY;
int clusterDimZ;
};
} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir
namespace mlir {
std::unique_ptr<Pass>
createTritonNvidiaGPUMaterializeLoadStorePass(int numWarps = 4,
int computeCapability = 80);
std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);
std::unique_ptr<Pass>
createTritonNvidiaGPUWSFeasibilityCheckingPass(int computeCapability = 90);
std::unique_ptr<Pass>
createTritonNvidiaGPUWSDecomposingPass(int computeCapability = 90);
std::unique_ptr<Pass>
createTritonNvidiaGPUWSPipelinePass(int numStages = 3, int numWarps = 4,
int computeCapability = 90);
std::unique_ptr<Pass>
createTritonNvidiaGPUWSMutexPass(int computeCapability = 90);
std::unique_ptr<Pass>
createTritonNvidiaGPUWSMaterializationPass(int computeCapability = 90);
std::unique_ptr<Pass>
createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);
std::unique_ptr<Pass>
createTritonGPURewriteTensorPointerPass(int computeCapability = 80);
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
} // namespace mlir
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
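A hedged sketch of how these factory functions plug into an mlir::PassManager; the selection and ordering below are illustrative, not the compiler's actual pipeline:
```
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

// Illustrative only: the real pipeline is assembled by the compiler driver.
void addHopperPasses(mlir::PassManager &pm, int computeCapability,
                     mlir::triton::nvidia_gpu::ClusterInfo &clusterInfo) {
  pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&clusterInfo));
  if (computeCapability >= 90) {
    pm.addPass(
        mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass(computeCapability));
    pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(computeCapability));
  }
  pm.addPass(mlir::createTritonNvidiaGPUMaterializeLoadStorePass(
      /*numWarps=*/4, computeCapability));
}
```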

View File

@@ -0,0 +1,228 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_PASSES
#define TRITONNVIDIAGPU_PASSES
include "mlir/Pass/PassBase.td"
def MaterializeLoadStore : Pass<"triton-nvidia-gpu-materialize-load-store", "mlir::ModuleOp"> {
let summary = "materialize load & store";
let description = [{
This pass runs after the pipeline pass, converting the remaining tt.LoadOps that take a
ptr<tensor> as input into ttg.InsertSliceAsyncOps and emitting the proper barriers.
}];
let constructor = "mlir::createTritonNvidiaGPUMaterializeLoadStorePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps per block">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> {
let summary = "plan CTA";
let description = [{
Plan CTAs in CGA
}];
let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()";
let dependentDialects = [
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];
}
def TritonGPUWSFeasibilityChecking : Pass<"triton-nvidia-gpu-ws-feasibility-checking", "mlir::ModuleOp"> {
let summary = "Attach attr named TritonNvidiaGPUDialect::getWSSupportedAttrName() if auto WS supported";
let description = [{
Since not every legal Triton kernel can be automatically warp-specialized, this pass performs a (conservative) check
and attaches an attribute named TritonNvidiaGPUDialect::getWSSupportedAttrName() on
the input module op if the kernel is supported.
}];
let constructor = "mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass()";
let dependentDialects = [
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"90",
"device compute capability">
];
}
def TritonGPUWSDecomposing : Pass<"triton-nvidia-gpu-ws-decomposing", "mlir::ModuleOp"> {
let summary = "Clustering on the ops according to their performance hotspots";
let description = [{
Based on compute capability and heuristics,
this pass will identify some operations to be executed in different agents,
by marking them with an async 'label'. E.g.,
input:
%1 = tt.load %0 ...
%4 = tt.dot %1, %2, %3 ...
output:
%1 = tt.load %0 {async_agent = 0} ...
%4 = tt.dot %1, %2, %3 {async_agent = 1} : ...
}];
let constructor = "mlir::createTritonNvidiaGPUWSDecomposingPass()";
let dependentDialects = [
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUWSPipeline : Pass<"triton-nvidia-gpu-ws-pipeline", "mlir::ModuleOp"> {
let summary = "Warp specialization pipeline";
let description = [{
}];
let constructor = "mlir::createTritonNvidiaGPUWSPipelinePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of pipeline stages">,
Option<"numWarps", "num-warps",
"int32_t", /*default*/"12",
"number of warps per block">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"90",
"device compute capability">
];
}
def TritonGPUWSMutex : Pass<"triton-nvidia-gpu-ws-mutex", "mlir::ModuleOp"> {
let summary = "Warp specialization mutex syncronization";
let description = [{
create mutex syncronization for persistent kernel. (as "2 Math WG" persistent kernel in cutlass)
For example, the agent containing dot and store will be divided into two sub-agent,
which execute dot and store alternately. i.e.:
sub-agent-0: dot | store | dot | ... | store
sub-agent-1: | dot | store | ... | dot | store
}];
let constructor = "mlir::createTritonNvidiaGPUWSMutexPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUWSMaterialization : Pass<"triton-nvidia-gpu-ws-materialization", "mlir::ModuleOp"> {
let summary = "Warp specialization materialization";
let description = [{
}];
let constructor = "mlir::createTritonNvidiaGPUWSMaterializationPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"90",
"device compute capability">
];
}
def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> {
let summary = "Insert fences across generic and async proxy";
let description = [{
}];
let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()";
let dependentDialects = [
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"90",
"device compute capability">
];
}
def TritonGPURewriteTensorPointer : Pass</*cli-arg*/"tritongpu-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
let description = [{
This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy
semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and the pass generates the logic to compute
the pointer/mask/other for each load/store.
}];
let constructor = "mlir::createTritonGPURewriteTensorPointerPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
#endif

View File

@@ -0,0 +1,95 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/MapVector.h"
namespace mlir {
// 0 is reserved for default sync.
// TODO: comprehensive mechanism to globally manage named barriers.
static int const nameBarrierIdBegin = 1;
static int nameBarrierIdEnd = 16;
/// Helper functions for async agent
typedef int AgentId;
SmallVector<AgentId> getAgentIds(Operation *op);
bool hasAgentId(Operation *op, AgentId agentId);
void setAgentIds(Operation *op, ArrayRef<AgentId> agentIds);
SmallVector<AgentId> collectAgentIds(Operation *op);
void addAgentIds(Operation *op, ArrayRef<int> agents);
SmallVector<int> getMutexBarIds(Operation *op);
SmallVector<int> getMutexNumThreads(Operation *op);
class OpBuilderWithAgentIds : public OpBuilder {
public:
OpBuilderWithAgentIds(MLIRContext *context) : OpBuilder(context) {}
void setAgentIdsFromArray(ArrayRef<AgentId> newAgentIds) {
agentIds = SmallVector<AgentId>(newAgentIds.begin(), newAgentIds.end());
}
void setAgentIdsFromOp(Operation *op) {
setAgentIdsFromArray(getAgentIds(op));
}
void setAgentIdsFromValueUsers(Value value) {
SetVector<AgentId> agentIdSet;
for (Operation *user : value.getUsers())
for (AgentId agentId : getAgentIds(user))
agentIdSet.insert(agentId);
setAgentIdsFromArray(agentIdSet.getArrayRef());
}
template <typename OpTy, typename... Args>
OpTy createWithAgentIds(Args &&...args) {
OpTy op = create<OpTy>(std::forward<Args>(args)...);
if (!agentIds.empty())
setAgentIds(op, agentIds);
return op;
}
private:
SmallVector<AgentId> agentIds;
};
/// Constant agent ids
constexpr AgentId kLoadAgentId = 0;
constexpr AgentId kDotAgentId = 1;
bool isWSCandidateLoad(Operation *op);
bool isWSSupported(ModuleOp m, int computeCapability);
LogicalResult getDependentValues(Value val, DenseSet<Value> &depSet,
const DenseSet<Value> &stopSet = {});
LogicalResult getDependentValues(Operation *op, DenseSet<Value> &depSet,
const DenseSet<Value> &stopSet = {});
DenseSet<Operation *> getDependentOps(DenseSet<Value> &depSet);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
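A small sketch of the agent-id helpers above, as they might be used when splitting work between a load agent and a dot agent; the two ops passed in are placeholders:
```
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"

// Tag two existing ops with the constant agent ids declared above, then set up
// a builder that propagates those ids to newly created ops.
void tagForWarpSpecialization(mlir::Operation *loadLike,
                              mlir::Operation *dotLike,
                              mlir::MLIRContext *ctx) {
  mlir::setAgentIds(loadLike, {mlir::kLoadAgentId});
  mlir::setAgentIds(dotLike, {mlir::kDotAgentId});

  mlir::OpBuilderWithAgentIds builder(ctx);
  builder.setAgentIdsFromOp(dotLike);
  // Any op created via builder.createWithAgentIds<OpTy>(...) now inherits
  // kDotAgentId automatically.
}
```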

View File

@@ -0,0 +1,19 @@
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
#define TRITON_TARGET_AMDGCNTRANSLATION_H
#include <string>
#include <tuple>
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
// Translate LLVM IR to AMDGCN code.
std::tuple<std::string, std::string>
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc);
} // namespace triton
#endif

View File

@@ -1,5 +1,6 @@
#ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#include "triton/Target/PTX/TmaMetadata.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
@@ -26,6 +27,7 @@ void addExternalLibs(mlir::ModuleOp &module,
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability,
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
bool isROCM);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
@@ -33,6 +35,9 @@ std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
bool isROCM);
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
llvm::StringRef path, bool isROCM);
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,107 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_TARGET_PTX_TMAMETADATA_H
#define TRITON_TARGET_PTX_TMAMETADATA_H
#include "python/triton/third_party/cuda/include/cuda.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <map>
#include <utility>
#include <vector>
namespace mlir {
namespace triton {
namespace gpu {
struct TMAInfo {
// --------------------------------------------
// information to be filled into CUtensorMaps
int tensorDataType;
uint32_t tensorRank;
// the argument index for the runtime to get the global address
size_t globalAddressArgIdx;
// the argument indices for the runtime to get globalDims; -1 means this
// dim is padded
std::vector<int32_t> globalDimsArgIdx;
// the argument indices for the runtime to get globalStrides; -1 means this dim
// is padded; the runtime needs to map the value to the internal format
std::vector<int32_t> globalStridesArgIdx;
std::vector<uint32_t> boxDims;
std::vector<uint32_t> elementStrides;
int interleave;
int swizzle;
int l2Promotion;
int oobFill;
// --------------------------------------------
// the argument index for the runtime to send the address of the tma_desc to
// the binary
int TMADescArgIdx;
template <typename T>
void dump_vec(const std::vector<T> &vec, llvm::StringRef info) const {
llvm::errs() << info << ": ";
for (const T &e : vec)
llvm::errs() << e << ",";
llvm::errs() << "\n";
}
void dump() const {
llvm::errs() << "TMA Info: ----------"
<< "\n";
llvm::errs() << "-- tensorDataType: " << tensorDataType
<< ", tensorRank: " << tensorRank << "\n";
llvm::errs() << "-- globalAddressArgIdx: " << globalAddressArgIdx << "\n";
llvm::errs() << "-- TMADescArgIdx: " << TMADescArgIdx << "\n";
dump_vec<int32_t>(globalDimsArgIdx, "-- globalDimsArgIdx");
dump_vec<int32_t>(globalStridesArgIdx, "-- globalStridesArgIdx");
dump_vec<uint32_t>(boxDims, "-- boxDims");
dump_vec<uint32_t>(elementStrides, "-- elementStrides");
llvm::errs() << "-- interleave: " << interleave << "\n";
llvm::errs() << "-- swizzle: " << swizzle << "\n";
llvm::errs() << "-- l2Promotion: " << l2Promotion << "\n";
llvm::errs() << "-- oobFill: " << oobFill << "\n";
};
};
using TMAMetadataTy = std::vector<TMAInfo>;
} // namespace gpu
} // namespace triton
} // namespace mlir
#endif // TRITON_TARGET_PTX_TMAMETADATA_H
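A small sketch of filling one TMAInfo entry; every value below is made up, purely to show which side (compiler vs. runtime) each field serves:
```
#include "triton/Target/PTX/TmaMetadata.h"

mlir::triton::gpu::TMAMetadataTy buildExampleMetadata() {
  mlir::triton::gpu::TMAInfo info;
  info.tensorDataType = 0;            // mapped by the runtime to a CUtensorMap dtype
  info.tensorRank = 2;
  info.globalAddressArgIdx = 0;       // kernel argument holding the base pointer
  info.globalDimsArgIdx = {3, 4};     // kernel arguments holding the two dims
  info.globalStridesArgIdx = {-1, 5}; // -1: padded dim with no runtime argument
  info.boxDims = {64, 64};
  info.elementStrides = {1, 1};
  info.interleave = 0;
  info.swizzle = 0;
  info.l2Promotion = 0;
  info.oobFill = 0;
  info.TMADescArgIdx = 6;             // where the runtime injects the TMA descriptor
  info.dump();                        // prints all of the above to llvm::errs()
  return {info};
}
```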

View File

@@ -25,9 +25,13 @@
#include <algorithm>
#include <cstdlib>
#include <string>
namespace triton {
const std::set<std::string> ENV_VARS = {
"ENABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION",
"ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP",
"AMDGCN_ENABLE_DUMP"};
namespace tools {
inline std::string getenv(const char *name) {
@@ -39,6 +43,9 @@ inline std::string getenv(const char *name) {
}
inline bool getBoolEnv(const std::string &env) {
std::string msg = "Environment variable " + env + " is not recognized";
assert(triton::ENV_VARS.find(env.c_str()) != triton::ENV_VARS.end() &&
msg.c_str());
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),

View File
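For illustration, a sketch of reading the new feature toggles through these helpers, assuming getBoolEnv lives in triton::tools as the hunk suggests; the header path is an assumption, and only variables listed in triton::ENV_VARS pass the assert in getBoolEnv:
```
#include <iostream>
#include "triton/Tools/Sys/GetEnv.hpp" // assumed location of the helpers above

void reportHopperToggles() {
  bool tmaEnabled = triton::tools::getBoolEnv("ENABLE_TMA");
  bool mmaV3Enabled = triton::tools::getBoolEnv("ENABLE_MMA_V3");
  std::cout << "ENABLE_TMA=" << tmaEnabled
            << " ENABLE_MMA_V3=" << mmaV3Enabled << "\n";
}
```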

@@ -2,6 +2,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
namespace mlir {
@@ -27,17 +28,21 @@ void SharedMemoryAliasAnalysis::visitOperation(
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// XXX(Keren): the following ops are always aliasing for now
if (isa<triton::gpu::ExtractSliceOp, triton::TransOp>(op)) {
if (isa<triton::gpu::ExtractSliceOp, triton::TransOp,
triton::nvidia_gpu::ExtractMBarrierOp>(op)) {
// extract_slice %src
// trans %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
op)) {
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp,
triton::nvidia_gpu::InsertSliceAsyncV2Op>(op)) {
// insert_slice_async %src, %dst, %index
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());
pessimistic = false;
} else if (isa<triton::nvidia_gpu::StoreAsyncOp>(op)) {
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (triton::gpu::isSharedEncoding(result)) {
aliasInfo.insert(result);
pessimistic = false;

View File

@@ -16,6 +16,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -57,11 +58,23 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
// MmaToDotShortcut doesn't use shared mem
if (srcLayout.isa<MmaEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>())
if (isMmaToDotShortcut(srcTy, dstTy))
return {};
if (shouldUseDistSmem(srcLayout, dstLayout)) {
// TODO: padding to avoid bank conflicts
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
}
// MmaToDotShortcut and MmaToMmaShortcut don't use shared mem
if (auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
if (dstLayout.isa<DotOperandEncodingAttr>()) {
if (isMmaToDotShortcut(srcTy, dstTy)) {
return {};
}
} else if (auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
if (isMmaToMmaShortcut(srcTy, dstTy)) {
return {};
}
}
}
assert(srcLayout && dstLayout &&
"Unexpected layout in getScratchConfigForCvtLayout()");
@@ -73,18 +86,18 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
auto srcShape = srcTy.getShape();
auto dstShape = dstTy.getShape();
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcShape);
auto dstShapePerCTA = getShapePerCTA(dstLayout, dstShape);
auto srcShapePerCTA = getShapePerCTA(srcTy);
auto dstShapePerCTA = getShapePerCTA(dstTy);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] =
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
}
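// Worked example with assumed shapes (not from this change): if
// srcShapePerCTA = {128, 64}, srcShapePerCTATile = {64, 32},
// dstShapePerCTA = {128, 64} and dstShapePerCTATile = {32, 64}, then
// paddedRepShape = {max(64, 32), max(32, 64)} = {64, 64}.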
if (rank == 1)
return paddedRepShape;
@@ -146,20 +159,45 @@ private:
// For example: %a = scf.if -> yield
// %a must be allocated elsewhere by other operations.
// FIXME(Keren): extract and insert are always alias for now
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op))
return;
}
// XXX(Keren): Why this hard-coded alignment?
size_t kAlignment = 8;
for (Value result : op->getResults()) {
if (triton::gpu::isSharedEncoding(result)) {
// Bytes could be a different value once we support padding or other
// allocation policies.
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
auto bytes = tensorType.getNumElements() *
auto shapePerCTA = triton::gpu::getShapePerCTA(tensorType);
auto bytes = product<int64_t>(shapePerCTA) *
tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
// XXX(Keren): magic numbers 256 and 1024
// benzh@: maybe the alignment should be passed in.
// Software swizzling calculates the phase based on the offset, while hardware
// swizzling does so based on the physical address. Thus only by setting the
// alignment to 1024 can we ensure correctness.
if (bytes > 256)
kAlignment = 1024;
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes,
kAlignment);
}
}
if (isa<triton::nvidia_gpu::AllocMBarrierOp>(op)) {
Value result = op->getResult(0);
if (!result.getType().isa<RankedTensorType>())
// In case AllocMBarrierOp is allocating scalar mbarriers
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, 8,
kAlignment);
}
}
template <BufferT::BufferKind T>
void maybeAddScratchBuffer(Operation *op, unsigned bytes,
unsigned alignment) {
if (bytes > 0)
allocation->addBuffer<T>(op, bytes, alignment);
}
template <BufferT::BufferKind T>
@@ -170,14 +208,17 @@ private:
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
const size_t scratchAlignment = 128;
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
@@ -201,7 +242,8 @@ private:
srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
@@ -218,7 +260,8 @@ private:
elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
auto value = op->getOperand(0);
@@ -230,13 +273,15 @@ private:
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
scratchAlignment);
}
}
@@ -356,6 +401,12 @@ private:
// Analyze liveness of explicit buffers
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value) {
// Shared memory allocated by mbarrier cannot be reused
if (value.getDefiningOp() &&
isa<triton::nvidia_gpu::AllocMBarrierOp>(value.getDefiningOp()))
return Interval(std::numeric_limits<size_t>::min(),
std::numeric_limits<size_t>::max());
auto liveOperations = liveness.resolveLiveness(value);
auto minId = std::numeric_limits<size_t>::max();
auto maxId = std::numeric_limits<size_t>::min();
@@ -437,17 +488,22 @@ private:
auto xRange = bufferRange[buffer];
bool res = xRange.intersects(range);
for (auto val : tripleMap)
res = res && !val.second.intersects(xRange);
res = res &&
!val.second.intersects(xRange); // only one buffer intersects
return res;
});
if (bufferIt != xBuffers.end()) {
auto buffer = *bufferIt;
auto xSize = buffer->size;
auto xRange = bufferRange.lookup(buffer);
bufferStart[buffer] = size;
tripleMap.insert(
{size + xSize, Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
// TODO(Keren): A buffer's size shouldn't be determined here; we have to
// clean it up
size_t alignment = buffer->alignment;
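// Round the running offset up to this buffer's alignment before placing it,
// e.g. with alignment = 1024 a current size of 300 rounds up to 1024 and
// 1500 rounds up to 2048 (illustrative numbers).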
size_t alignSize = ((size + alignment - 1) / alignment) * alignment;
bufferStart[buffer] = alignSize;
tripleMap.insert({alignSize + xSize,
Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
// We could either insert (range.start, xRange.start) or (range.start,
// xRange.end), both are correct and determine the potential buffer
// offset, and the graph coloring algorithm will solve the interference,

View File

@@ -14,4 +14,5 @@ add_mlir_library(TritonAnalysis
MLIRLLVMDialect
TritonIR
TritonGPUIR
TritonNvidiaGPUIR
)

View File

@@ -2,7 +2,11 @@
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "../lib/Conversion/TritonGPUToLLVM/Utility.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <deque>
@@ -103,7 +107,11 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
return;
}
if (isa<gpu::BarrierOp>(op)) {
// TODO(Keren): Don't expose LLVM Dialect ops here
if (isa<gpu::BarrierOp>(op) ||
(isa<LLVM::InlineAsmOp>(op) &&
(dyn_cast<LLVM::InlineAsmOp>(op).getAsmString().find("bar.sync") !=
std::string::npos))) {
// If the current op is a barrier, we sync previous reads and writes
blockInfo->sync();
return;
@@ -169,12 +177,23 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
if (blockInfo->isIntersected(curBlockInfo)) {
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPoint(op);
builder->create<gpu::BarrierOp>(op->getLoc());
blockInfo->sync();
// TODO(Keren): Don't expose LLVM Dialect ops here
// TODO[shuhaoj]: remove the hard-coded numThreads and hide the async_agent
// attr; find a better way to determine barId (the number of agents is limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
barSync(*builder, op, barId, 128);
} else {
builder->create<gpu::BarrierOp>(op->getLoc());
blockInfo->sync();
}
}
// Update the region info, even if barrier is inserted, we have to maintain
// the current op's read/write buffers.
blockInfo->join(curBlockInfo);
}
} // namespace mlir

View File

@@ -1,10 +1,14 @@
#include "triton/Analysis/Utility.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include <deque>
@@ -37,6 +41,51 @@ bool ReduceOpHelper::isFastReduction() {
getParentOrder(getSrcLayout())[0];
}
// Cases where distributed shared memory is not required in ConvertLayout:
// (1) numCTAs == 1
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented
// in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout);
assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) &&
"Invalid layout conversion: the numbers of CTAs of src and dst "
"layouts are different");
// Case (1): Never use dsmem when numCTAs == 1
if (numCTAs == 1)
return false;
// Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not
// implemented yet
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
auto dim = sliceLayout.getDim();
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent());
if (CTAsPerCGA[dim] != 1)
assert(0 && "Layout conversion to be implemented");
}
// Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported
if (auto sliceLayout = dstLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
auto dim = sliceLayout.getDim();
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent());
if (CTAsPerCGA[dim] != 1)
return true;
}
// The above two branches make sure that it is legal to call getCTALayout of
// srcLayout and dstLayout
// Case (2): Do not use dsmem when srcCTALayout == dstCTALayout
auto srcCTALayout = triton::gpu::getCTALayout(srcLayout);
auto dstCTALayout = triton::gpu::getCTALayout(dstLayout);
if (srcCTALayout == dstCTALayout)
return false;
// Dsmem access is required when srcCTALayout != dstCTALayout
return true;
}
unsigned ReduceOpHelper::getInterWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
@@ -136,7 +185,7 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
if (auto mmaLayout = srcLayout.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
return true;
}
}
@@ -282,6 +331,8 @@ bool maybeSharedAllocationOp(Operation *op) {
return dialect &&
(dialect->getTypeID() ==
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
dialect->getTypeID() ==
mlir::TypeID::get<triton::nvidia_gpu::TritonNvidiaGPUDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<arith::ArithDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
@@ -290,6 +341,8 @@ bool maybeSharedAllocationOp(Operation *op) {
bool maybeAliasOp(Operation *op) {
return isa<triton::gpu::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op) ||
isa<triton::nvidia_gpu::StoreAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}
@@ -299,6 +352,21 @@ bool supportMMA(triton::DotOp op, int version) {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
if (version == 3) {
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
auto retType = op.getResult().getType().cast<RankedTensorType>();
auto retShapePerCTA = triton::gpu::getShapePerCTA(retType);
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
if (!(numWarps % 4 == 0 && retShapePerCTA[0] % 64 == 0 &&
retShapePerCTA[1] % 8 == 0 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
}
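// Illustrative example (assumed shapes): an fp16 dot whose accumulator tile
// per CTA is 128x128 with numWarps = 8 passes the checks above (8 % 4 == 0,
// 128 % 64 == 0, 128 % 8 == 0), so MMAv3 can be used once ENABLE_MMA_V3 is set.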
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.getAllowTF32() && version >= 2;
}
@@ -306,24 +374,21 @@ bool supportMMA(triton::DotOp op, int version) {
}
bool supportMMA(Value value, int version) {
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
// Tell whether a DotOp supports MMA by the operand type (either $a or $b).
// We cannot get both operand types (in TypeConverter), so we assume here that
// the types of both operands are identical.
assert((version == 1 || version == 2) &&
assert((version == 1 || version == 2 || version == 3) &&
"Unexpected MMA layout version found");
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
// FP8 is not natively supported on all MMA versions, but it can always be
// promoted to fp16, so we can always support it.
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
}
Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
@@ -338,6 +403,17 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
!srcTy.getElementType().isF32();
}
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy);
auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy);
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 &&
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
srcElemsPerThread == dstElemsPerThread;
}
bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
@@ -557,4 +633,81 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
return solver;
}
static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
return makeTensorPtrOp;
}
if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
return getMakeTensorPtrOp(advanceOp.getPtr());
}
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
auto idx = v.cast<OpResult>().getResultNumber();
llvm::SmallVector<scf::YieldOp> yieldOps;
op->walk([&](Operation *op) {
if (auto yieldOp = dyn_cast<scf::YieldOp>(op))
yieldOps.push_back(yieldOp);
});
// benzh@: if there are multiple yields, all yield operands should come from the same arg.
Value newValue = yieldOps[0].getOperands()[idx];
return getMakeTensorPtrOp(newValue);
}
llvm_unreachable("Unable to getMakeTensorPtr()");
}
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) {
using BranchOps = llvm::SetVector<std::pair<Operation *, int>>;
llvm::DenseMap<Block *, BranchOps> blockToCFOps;
auto moduleOp =
v.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();
moduleOp.walk([&](Operation *op) {
if (auto br = dyn_cast<cf::BranchOp>(op)) {
Block *block = br.getDest();
blockToCFOps[block].insert({op, -1});
}
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
Block *blockT = condBr.getTrueDest();
Block *blockF = condBr.getFalseDest();
blockToCFOps[blockT].insert({condBr, 1});
blockToCFOps[blockF].insert({condBr, 0});
}
});
if (Operation *definingOp = v.getDefiningOp()) {
return getMakeTensorPtrOpImpl(definingOp, v);
} else if (BlockArgument arg = v.cast<BlockArgument>()) {
unsigned argNum = arg.getArgNumber();
Operation *argOwner = arg.getOwner()->getParentOp();
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
return getMakeTensorPtrOp(
forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));
} else if (auto funcOp = dyn_cast<mlir::triton::FuncOp>(argOwner)) {
Block *block = arg.getOwner();
Operation *op;
int tOrF;
std::tie(op, tOrF) = blockToCFOps[block][0];
if (auto br = dyn_cast<cf::BranchOp>(op)) {
return getMakeTensorPtrOp(br.getDestOperands()[argNum]);
}
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
if (tOrF) {
return getMakeTensorPtrOp(condBr.getTrueDestOperands()[argNum]);
} else {
return getMakeTensorPtrOp(condBr.getFalseDestOperands()[argNum]);
}
}
} else {
return getMakeTensorPtrOp(argOwner->getOperand(argNum));
}
}
llvm_unreachable("Unable to getMakeTensorPtr()");
}
} // namespace mlir

View File

@@ -3,3 +3,4 @@ add_subdirectory(Analysis)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Target)
add_subdirectory(Hopper)

View File

@@ -0,0 +1,217 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "BarrierOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
// --------------------------------------------------------------------------
// -- MBarrier-related Ops lowering, to be moved to a separate file ---------
// --------------------------------------------------------------------------
struct AllocMBarrierOpConversion : public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::AllocMBarrierOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::AllocMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::AllocMBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
auto resultTy = op.getType();
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
Type elemPtrTy;
if (resultTensorTy) {
auto llvmElemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
elemPtrTy = ptr_ty(llvmElemTy, 3);
} else {
elemPtrTy = getTypeConverter()->convertType(resultTy);
}
smemBase = bitcast(smemBase, elemPtrTy);
auto threadId = getThreadId(rewriter, loc);
auto pred = icmp_eq(threadId, i32_val(0));
int numMBarriers = 1;
if (resultTensorTy) {
assert(resultTensorTy.getRank() == 1 &&
"unexpected rank for AllocMBarrierOp");
numMBarriers = resultTensorTy.getShape()[0];
}
for (int i = 0; i < numMBarriers; ++i) {
Value smem = smemBase;
if (i > 0) {
smem = gep(elemPtrTy, smem, i32_val(i));
}
rewriter.create<triton::nvgpu::MBarrierInitOp>(loc, smem, pred,
op.getCount());
}
if (resultTensorTy) {
auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(),
{0}, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
} else {
rewriter.replaceOp(op, smemBase);
}
return success();
}
};
struct MBarrierArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::MBarrierArriveOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::MBarrierArriveOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto mbarrier = adaptor.getMbarrier();
bool trackAsyncOp = op.getTrackAsyncOp();
triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal;
uint32_t txCount = op.getTxCount();
auto remoteCtaId = adaptor.getRemoteCtaId();
if (trackAsyncOp) {
type = triton::nvgpu::MBarriveType::cp_async;
} else if (remoteCtaId) {
assert(txCount == 0 &&
"remote arrive of transaction mbarrier is not implemented yet");
type = triton::nvgpu::MBarriveType::remote;
} else if (txCount > 0) {
type = triton::nvgpu::MBarriveType::expect_tx;
}
Value pred = adaptor.getPred();
if (pred == nullptr) {
pred = int_val(/*width*/ 1, 1);
}
rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierArriveOp>(
op, mbarrier, pred, remoteCtaId, type, txCount);
return success();
}
};
struct MBarrierWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::MBarrierWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::MBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::MBarrierWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierWaitOp>(
op, adaptor.getMbarrier(), adaptor.getPhase());
return success();
}
};
struct ExtractMBarrierOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::ExtractMBarrierOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::ExtractMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::ExtractMBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto elemTy =
op.getTensor().getType().cast<RankedTensorType>().getElementType();
auto tensorStruct = adaptor.getTensor();
auto index = adaptor.getIndex();
auto ptrTy =
LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3);
auto basePtr =
extract_val(ptrTy, tensorStruct, rewriter.getDenseI64ArrayAttr(0));
Value result = gep(ptrTy, basePtr, index);
rewriter.replaceOp(op, result);
return success();
}
};
struct NamedBarrierArriveOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::NamedBarrierArriveOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::NamedBarrierArriveOp>::
ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierArriveOp>(
op, adaptor.getBar(), adaptor.getNumThreads());
return success();
}
};
struct NamedBarrierWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::NamedBarrierWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierWaitOp>(
op, adaptor.getBar(), adaptor.getNumThreads());
return success();
}
};
struct FenceAsyncSharedOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::FenceAsyncSharedOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::FenceAsyncSharedOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::FenceAsyncSharedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
rewriter.replaceOpWithNewOp<triton::nvgpu::FenceAsyncSharedOp>(
op, adaptor.getBCluster());
return success();
}
};
void populateBarrierOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit) {
patterns.add<AllocMBarrierOpConversion>(typeConverter, allocation, benefit);
patterns.add<MBarrierArriveOpConversion>(typeConverter, allocation, benefit);
patterns.add<MBarrierWaitOpConversion>(typeConverter, allocation, benefit);
patterns.add<ExtractMBarrierOpConversion>(typeConverter, allocation, benefit);
patterns.add<NamedBarrierArriveOpConversion>(typeConverter, allocation,
benefit);
patterns.add<NamedBarrierWaitOpConversion>(typeConverter, allocation,
benefit);
patterns.add<FenceAsyncSharedOpConversion>(typeConverter, allocation,
benefit);
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateBarrierOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit);
#endif

View File

@@ -1,4 +1,16 @@
add_mlir_conversion_library(TritonGPUToLLVM
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
ConvertLayoutOpToLLVM.cpp
DotOpToLLVM/FMA.cpp
DotOpToLLVM/MMAv1.cpp
DotOpToLLVM/MMAv2.cpp
DotOpToLLVM/WGMMA.cpp
DotOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
LoadStoreOpToLLVM.cpp
BarrierOpToLLVM.cpp
TritonGPUToLLVM.cpp
GCNAsmFormat.cpp
PTXAsmFormat.cpp
@@ -15,12 +27,16 @@ add_mlir_conversion_library(TritonGPUToLLVM
LoadStoreOpToLLVM.cpp
TritonGPUToLLVM.cpp
TritonGPUToLLVMPass.cpp
GCNAsmFormat.cpp
PTXAsmFormat.cpp
ReduceOpToLLVM.cpp
ScanOpToLLVM.cpp
Utility.cpp
TypeConverter.cpp
Utility.cpp
ViewOpToLLVM.cpp
TensorPtrOpsToLLVM.cpp
ClusterOpsToLLVM.cpp
RegReallocOpToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
@@ -43,4 +59,6 @@ add_mlir_conversion_library(TritonGPUToLLVM
TritonIR
TritonGPUIR
TritonGPUTransforms
TritonNvidiaGPUTransforms
NVGPUIR
)

View File

@@ -0,0 +1,62 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "ClusterOpsToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
struct ClusterArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::ClusterArriveOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::ClusterArriveOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterArriveOp>(
op, op.getRelaxed());
return success();
}
};
struct ClusterWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::ClusterWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::ClusterWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterWaitOp>(op);
return success();
}
};
void populateClusterOpsToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit) {
patterns.add<ClusterArriveOpConversion>(typeConverter, benefit);
patterns.add<ClusterWaitOpConversion>(typeConverter, benefit);
return;
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateClusterOpsToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit);
#endif

View File

@@ -1,14 +1,18 @@
#include "ConvertLayoutOpToLLVM.h"
#include "Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::linearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
@@ -72,6 +76,13 @@ public:
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, rewriter);
}
// Forward on the mma->mma shortcut; lower distributed->distributed otherwise.
if (srcLayout.isa<MmaEncodingAttr>() && dstLayout.isa<MmaEncodingAttr>()) {
if (isMmaToMmaShortcut(srcTy, dstTy)) {
rewriter.replaceOp(op, op.getSrc());
return success();
}
}
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
@@ -89,23 +100,25 @@ public:
}
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
SmallVector<Value>
getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter, unsigned elemId,
RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTATile) const {
auto shape = type.getShape();
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type, false);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
multiDimOffset[d] =
add(multiDimOffsetFirstElem[d],
i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] +
multiDimElemId[d]));
}
return multiDimOffset;
}
@@ -127,7 +140,7 @@ private:
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
sliceLayout.paddedShape(shapePerCTATile));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
@@ -138,6 +151,8 @@ private:
return multiDimOffset;
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
auto instrShape = mmaLayout.getInstrShape();
SmallVector<Value> mmaColIdx(4);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
@@ -145,27 +160,35 @@ private:
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in the MMAEncodingAttr documentation
SmallVector<Value> multiDimWarpId(2);
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
auto order = triton::gpu::getOrder(mmaLayout);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
if (mmaLayout.isHopper()) {
multiDimWarpId[0] = urem(warpId, i32_val(warpsPerCTA[0]));
multiDimWarpId[1] = udiv(warpId, i32_val(warpsPerCTA[0]));
} else {
auto order = triton::gpu::getOrder(mmaLayout);
multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order);
}
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _8 = i32_val(8);
Value _16 = i32_val(16);
if (mmaLayout.isAmpere()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
multiDimWarpId[0] =
urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0]));
multiDimWarpId[1] =
urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1]));
Value mmaGrpId = udiv(laneId, _4);
Value mmaGrpIdP8 = add(mmaGrpId, _8);
Value mmaThreadIdInGrp = urem(laneId, _4);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
Value rowWarpOffset = mul(multiDimWarpId[0], i32_val(instrShape[0]));
mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
Value colWarpOffset = mul(multiDimWarpId[1], _8);
Value colWarpOffset = mul(multiDimWarpId[1], i32_val(instrShape[1]));
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
} else if (mmaLayout.isVolta()) {
@@ -176,13 +199,27 @@ private:
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
if (mmaLayout.isAmpere()) {
if (mmaLayout.isHopper()) {
unsigned elemIdRem4 = elemId % 4;
unsigned nGrpId = elemId / 4;
multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId));
multiDimOffset[0] =
add(multiDimOffset[0],
i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0]));
multiDimOffset[1] =
add(multiDimOffset[1],
i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1]));
} else if (mmaLayout.isAmpere()) {
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
multiDimOffset[0] =
add(multiDimOffset[0],
i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0]));
multiDimOffset[1] =
add(multiDimOffset[1],
i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1]));
} else if (mmaLayout.isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
@@ -211,11 +248,12 @@ private:
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
SmallVector<unsigned> numCTAs(rank);
SmallVector<unsigned> numCTATiles(rank);
auto shapePerCTATile = getShapePerCTATile(layout);
auto shapePerCTA = getShapePerCTA(layout, type.getShape());
auto order = getOrder(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
numCTATiles[d] = ceil<unsigned>(shapePerCTA[d], shapePerCTATile[d]);
}
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
@@ -238,17 +276,16 @@ private:
}
auto linearCTAId =
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
getLinearIndex<unsigned>(multiDimCTAId, numCTATiles, order);
// TODO: This is actually redundant index calculation; we should
// consider caching the index calculation result in case a
// performance issue is observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
multiDimCTAInRepId, shapePerCTATile);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
@@ -305,7 +342,8 @@ private:
SmallVector<unsigned> numCTAs(rank, 1);
SmallVector<unsigned> numCTAsEachRep(rank, 1);
SmallVector<unsigned> shapePerCTA = getShapePerCTA(layout, shape);
SmallVector<unsigned> shapePerCTATile = getShapePerCTATile(layout, shape);
SmallVector<int64_t> shapePerCTA = getShapePerCTA(layout, shape);
auto elemTy = type.getElementType();
int ctaId = 0;
@@ -335,7 +373,7 @@ private:
// duplicate in Volta.
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
multiDimCTAInRepId, shapePerCTATile);
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
}
@@ -343,7 +381,7 @@ private:
// do transpose
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numM = aEncoding.getMMAv1NumOuter(shapePerCTA);
int numN = accumSizePerThread / numM;
for (int r = 0; r < numM; r++) {
@@ -382,6 +420,91 @@ private:
}
}
LogicalResult
lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto srcShapePerCTA = getShapePerCTA(srcTy);
auto srcCTAsPerCGA = triton::gpu::getCTAsPerCGA(srcLayout);
auto srcCTAOrder = triton::gpu::getCTAOrder(srcLayout);
unsigned rank = srcShapePerCTA.size();
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
auto smemShape = convertType<unsigned, int64_t>(srcShapePerCTA);
// Store to local shared memory
{
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
auto inIndices =
emitIndices(loc, rewriter, srcLayout, srcTy, /*withCTAOffset*/ false);
assert(inIndices.size() == inVals.size() &&
"Unexpected number of indices emitted");
for (unsigned i = 0; i < inIndices.size(); ++i) {
Value offset = linearize(rewriter, loc, inIndices[i], smemShape);
Value ptr = gep(elemPtrTy, smemBase, offset);
store(inVals[i], ptr);
}
}
// Cluster barrier
rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
// Load from remote shared memory
{
SmallVector<Value> srcShapePerCTACache;
for (unsigned i = 0; i < rank; ++i)
srcShapePerCTACache.push_back(i32_val(srcShapePerCTA[i]));
SmallVector<Value> outVals;
auto outIndices =
emitIndices(loc, rewriter, dstLayout, dstTy, /*withCTAOffset*/ true);
for (unsigned i = 0; i < outIndices.size(); ++i) {
auto coord = outIndices[i];
assert(coord.size() == rank && "Unexpected rank of index emitted");
SmallVector<Value> multiDimCTAId, localCoord;
for (unsigned d = 0; d < rank; ++d) {
multiDimCTAId.push_back(udiv(coord[d], srcShapePerCTACache[d]));
localCoord.push_back(urem(coord[d], srcShapePerCTACache[d]));
}
Value remoteCTAId =
linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder);
Value localOffset = linearize(rewriter, loc, localCoord, smemShape);
Value ptr = gep(elemPtrTy, smemBase, localOffset);
outVals.push_back(load_dsmem(ptr, remoteCTAId));
}
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);
}
// Cluster barrier
rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
return success();
}
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflicts.
LogicalResult
@@ -395,6 +518,10 @@ private:
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (shouldUseDistSmem(srcLayout, dstLayout))
return lowerDistToDistWithDistSmem(op, adaptor, rewriter);
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
@@ -406,8 +533,9 @@ private:
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape());
auto dstShapePerCTA = getShapePerCTA(dstLayout, shape);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
auto shapePerCTA = getShapePerCTA(srcLayout, shape);
// For Volta, all the coords for a CTA are calculated.
bool isSrcMmaV1{}, isDstMmaV1{};
@@ -427,15 +555,17 @@ private:
}
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned inPerCTA =
std::min<unsigned>(shapePerCTA[d], srcShapePerCTATile[d]);
unsigned outPerCTA =
std::min<unsigned>(shapePerCTA[d], dstShapePerCTATile[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
numReplicates[d] = ceil<unsigned>(shapePerCTA[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
inNumCTAs[d] = ceil<unsigned>(shapePerCTA[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shapePerCTA[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
@@ -456,8 +586,26 @@ private:
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId =
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
if (repId != 0)
barrier();
if (repId != 0) {
// TODO[shuhaoj]: remove the hard-coded numThreads and hide the async_agent
// attr; find a better way to determine barId (the number of agents is limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId =
op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(128));
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
kNumThreads);
} else {
barrier();
}
}
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
@@ -474,7 +622,23 @@ private:
return failure();
}
barrier();
// TODO[shuhaoj]: remove the hard-coded numThreads and hide the async_agent
// attr; find a better way to determine barId (the number of agents is limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(128));
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
kNumThreads);
} else {
barrier();
}
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
@@ -545,7 +709,7 @@ private:
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy);
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcLayout = srcTy.getEncoding();
@@ -557,13 +721,93 @@ private:
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto dstStrides =
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
int32_t elemSize = elemTy.getIntOrFloatBitWidth();
auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
if (mmaLayout && mmaLayout.isHopper() && elemSize == 16 &&
inOrd == outOrd && numElems >= 16) {
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
auto srcShapePerCTA = getShapePerCTA(mmaLayout, srcShape);
auto instrShape = mmaLayout.getInstrShape();
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
uint32_t repM =
ceil<unsigned>(srcShapePerCTA[0], instrShape[0] * warpsPerCTA[0]);
uint32_t numElemsPerRep = numElems / repM;
// rowStride in bytes
uint32_t rowStrideInBytes = dstShapePerCTA[outOrd[0]] * 2;
uint32_t swizzlingByteWidth = rowStrideInBytes;
if (swizzlingByteWidth > 128)
swizzlingByteWidth = 128;
unsigned numElemsPerSwizzlingRow = swizzlingByteWidth * 8 / elemSize;
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]];
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
typeConverter->convertType(rewriter.getI8Type()), 3);
uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0];
Value threadId = getThreadId(rewriter, loc);
Value warpId = udiv(threadId, i32_val(32));
Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])),
i32_val(srcShape[0] / instrShape[0]));
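// Each inner-loop iteration below packs 8 consecutive 16-bit elements into
// four 32-bit registers (CvtPackOp) and writes them with a single
// StoreMatrixOp at a swizzled shared-memory offset (OffsetOfStmatrixV4Op).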
for (int rep = 0; rep < repM; ++rep) {
Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])),
i32_val(rep * rowsPerRep));
uint32_t elemIdxOffset = rep * numElemsPerRep;
for (unsigned idx = 0; idx < numElemsPerRep; idx += 8) {
uint32_t elemIdx = elemIdxOffset + idx;
Value offset = rewriter.create<triton::nvgpu::OffsetOfStmatrixV4Op>(
loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset,
numElemsPerSwizzlingRow, true);
Value addr = gep(elemPtrTy, smemBase, offset);
Value data0 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 1], inVals[elemIdx + 0]);
Value data1 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 3], inVals[elemIdx + 2]);
Value data2 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 5], inVals[elemIdx + 4]);
Value data3 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 7], inVals[elemIdx + 6]);
rewriter.create<triton::nvgpu::StoreMatrixOp>(
loc, bitcast(addr, ptrI8SharedTy),
ValueRange{data0, data1, data2, data3});
}
}
// TODO[shuhaoj]: remove the hard-coded numThreads and hide the async_agent
// attr; find a better way to determine barId (the number of agents is limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(128));
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
kNumThreads);
} else {
barrier();
}
} else {
auto dstStrides =
getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false);
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices,
dst, smemBase, elemTy, loc, rewriter);
}
auto smemObj =
SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
SharedMemoryObject(smemBase, dstShapePerCTA, outOrd, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
@@ -688,19 +932,16 @@ private:
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor());
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
Value res;
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2
res = SharedToDotOperandMMAv2::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, getTypeConverter(), tid_val());
} else if (!isOuter && mmaLayout.isVolta() &&
supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
smemObj, getTypeConverter(), getThreadId(rewriter, loc));
} else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
@@ -722,10 +963,11 @@ private:
}
return res;
}
}; // namespace triton::gpu::ConvertLayoutOp
}; // namespace triton::gpu::ConvertLayoutOp>
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {

View File

@@ -10,6 +10,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -2,8 +2,10 @@
#include "../Utility.h"
using ValueTable = std::map<std::pair<int, int>, Value>;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
@@ -14,31 +16,32 @@ using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;
SmallVector<Value>
getThreadIds(Value threadId, ArrayRef<unsigned int> shapePerCTA,
getThreadIds(Value threadId, ArrayRef<unsigned int> shapePerCTATile,
ArrayRef<unsigned int> sizePerThread, ArrayRef<unsigned int> order,
ConversionPatternRewriter &rewriter, Location loc) {
int dim = order.size();
SmallVector<Value> threadIds(dim);
for (unsigned k = 0; k < dim - 1; k++) {
Value dimK = i32_val(shapePerCTA[order[k]] / sizePerThread[order[k]]);
Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]);
Value rem = urem(threadId, dimK);
threadId = udiv(threadId, dimK);
threadIds[order[k]] = rem;
}
Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
Value dimK = i32_val(shapePerCTATile[order[dim - 1]]);
threadIds[order[dim - 1]] = urem(threadId, dimK);
return threadIds;
}
int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) {
// Get shapePerCTATile for M or N axis.
int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) {
auto order = layout.getOrder();
auto shapePerCTA = getShapePerCTA(layout);
auto shapePerCTATile = getShapePerCTATile(layout);
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
return isM ? mShapePerCTA : nShapePerCTA;
int mShapePerCTATile =
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
int nShapePerCTATile =
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
return isM ? mShapePerCTATile : nShapePerCTATile;
}
// Get sizePerThread for M or N axis.
@@ -91,7 +94,7 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
ConversionPatternRewriter &rewriter) {
auto aTensorTy = A.getType().cast<RankedTensorType>();
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto aShape = aTensorTy.getShape();
auto aShapePerCTA = getShapePerCTA(aTensorTy);
auto aOrder = aLayout.getOrder();
auto order = dLayout.getOrder();
@@ -104,10 +107,10 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK;
int aNumPtr = 8;
int K = aShape[1];
int M = aShape[0];
int K = aShapePerCTA[1];
int M = aShapePerCTA[0];
auto shapePerCTA = getShapePerCTA(dLayout);
auto shapePerCTATile = getShapePerCTATile(dLayout);
auto sizePerThread = getSizePerThread(dLayout);
Value _0 = i32_val(0);
@@ -115,8 +118,8 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
Value mContig = i32_val(sizePerThread[order[1]]);
// threadId in blocked layout
auto threadIds =
getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc);
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
rewriter, loc);
Value threadIdM = threadIds[0];
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
@@ -134,11 +137,11 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
SmallVector<Value> vas;
int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/);
int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/);
int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/);
for (unsigned k = 0; k < K; ++k)
for (unsigned m = 0; m < M; m += mShapePerCTA)
for (unsigned m = 0; m < M; m += mShapePerCTATile)
for (unsigned mm = 0; mm < mSizePerThread; ++mm) {
Value offset =
add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK));
@@ -155,7 +158,7 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
ConversionPatternRewriter &rewriter) {
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto bShape = bTensorTy.getShape();
auto bShapePerCTA = getShapePerCTA(bTensorTy);
auto bOrder = bLayout.getOrder();
auto order = dLayout.getOrder();
@@ -168,10 +171,10 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
Value strideB0 = isBRow ? strideBN : strideBK;
Value strideB1 = isBRow ? strideBK : strideBN;
int bNumPtr = 8;
int K = bShape[0];
int N = bShape[1];
int K = bShapePerCTA[0];
int N = bShapePerCTA[1];
auto shapePerCTA = getShapePerCTA(dLayout);
auto shapePerCTATile = getShapePerCTATile(dLayout);
auto sizePerThread = getSizePerThread(dLayout);
Value _0 = i32_val(0);
@@ -179,8 +182,8 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
Value nContig = i32_val(sizePerThread[order[0]]);
// threadId in blocked layout
auto threadIds =
getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc);
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
rewriter, loc);
Value threadIdN = threadIds[1];
Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
@@ -198,11 +201,11 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
SmallVector<Value> vbs;
int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/);
int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/);
int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/);
for (unsigned k = 0; k < K; ++k)
for (unsigned n = 0; n < N; n += nShapePerCTA)
for (unsigned n = 0; n < N; n += nShapePerCTATile)
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
Value offset =
add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK));

View File

@@ -504,7 +504,7 @@ std::function<void(int, int)> getLoadMatrixFn(
.cast<RankedTensorType>()
.getElementType()
.isa<mlir::Float8E4M3B11FNUZType>()) {
bool noTrans = (isA ^ order[0] == 0);
bool noTrans = (isA ^ (order[0] == 0));
assert(noTrans && "float8e4b15 must have row-col layout");
}

View File

@@ -4,7 +4,9 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
@@ -19,6 +21,16 @@ LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Value thread);
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Value thread);
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
@@ -26,14 +38,15 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// D = A * B + C
Value A = op.getA();
Value D = op.getResult();
// Here we assume the DotOp's operands always come from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
auto AShapePerCTA = getShapePerCTA(A.getType());
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
unsigned K = AShapePerCTA[reduceAxis];
bool isOuter = K == 1;
MmaEncodingAttr mmaLayout = D.getType()
@@ -45,6 +58,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
return convertMMA884(op, adaptor, getTypeConverter(), rewriter);
if (mmaLayout.isAmpere())
return convertMMA16816(op, adaptor, getTypeConverter(), rewriter);
if (mmaLayout.isHopper())
return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
getThreadId(rewriter, loc));
llvm::report_fatal_error(
"Unsupported MMA kind found when converting DotOp to LLVM.");
@@ -61,9 +77,68 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
}
};
struct DotAsyncOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotAsyncOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::DotAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::DotAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
// D = A * B + C
Value A = op.getA();
Value D = op.getResult();
// Here we assume the DotOp's operands always come from shared memory.
auto AShapePerCTA = getShapePerCTA(A.getType());
size_t reduceAxis = 1;
unsigned K = AShapePerCTA[reduceAxis];
bool isOuter = K == 1;
MmaEncodingAttr mmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
if (!isOuter && mmaLayout &&
supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) {
if (mmaLayout.isHopper()) {
return convertAsyncWGMMA(op, adaptor, getTypeConverter(), rewriter,
getThreadId(rewriter, loc));
}
llvm::report_fatal_error(
"Unsupported MMA kind found when converting DotAsyncOp to LLVM.");
}
llvm::report_fatal_error(
"Unsupported DotAsyncOp found when converting TritonGPU to LLVM.");
}
};
struct DotWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::DotWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto pendings = op.getPendings();
rewriter.create<triton::nvgpu::WGMMAWaitOp>(op.getLoc(), pendings);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
PatternBenefit benefit) {
patterns.add<DotOpConversion>(typeConverter, allocation, benefit);
patterns.add<DotAsyncOpConversion>(typeConverter, allocation, benefit);
patterns.add<DotWaitOpConversion>(typeConverter, allocation, benefit);
}

View File

@@ -7,7 +7,8 @@ using namespace mlir;
using namespace mlir::triton;
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
PatternBenefit benefit);

View File

@@ -5,19 +5,20 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ValueTableFMA = std::map<std::pair<int, int>, Value>;
static ValueTableFMA getValueTableFromStructFMA(
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
Value val, int K, int n0, int shapePerCTATile, int sizePerThread,
ConversionPatternRewriter &rewriter, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) {
ValueTableFMA res;
auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type);
int index = 0;
for (unsigned k = 0; k < K; ++k) {
for (unsigned m = 0; m < n0; m += shapePerCTA)
for (unsigned m = 0; m < n0; m += shapePerCTATile)
for (unsigned mm = 0; mm < sizePerThread; ++mm) {
res[{m + mm, k}] = elems[index++];
}
@@ -40,8 +41,8 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto dTensorTy = D.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto bShape = bTensorTy.getShape();
auto aShapePerCTA = getShapePerCTA(aTensorTy);
auto bShapePerCTA = getShapePerCTA(bTensorTy);
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -53,41 +54,42 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
Value llB = adaptor.getB();
auto sizePerThread = getSizePerThread(dLayout);
auto shapePerCTA = getShapePerCTA(dLayout);
auto shapePerCTATile = getShapePerCTATile(dLayout);
int K = aShape[1];
int M = aShape[0];
int N = bShape[1];
int K = aShapePerCTA[1];
int M = aShapePerCTA[0];
int N = bShapePerCTA[1];
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int mShapePerCTATile =
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
int mSizePerThread =
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nShapePerCTATile =
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
int nSizePerThread =
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
auto has =
getValueTableFromStructFMA(llA, K, M, mShapePerCTA, mSizePerThread,
getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread,
rewriter, loc, typeConverter, aTensorTy);
auto hbs =
getValueTableFromStructFMA(llB, K, N, nShapePerCTA, nSizePerThread,
getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread,
rewriter, loc, typeConverter, bTensorTy);
SmallVector<Value> ret = cc;
bool isCRow = order[0] == 1;
for (unsigned k = 0; k < K; k++) {
for (unsigned m = 0; m < M; m += mShapePerCTA)
for (unsigned n = 0; n < N; n += nShapePerCTA)
for (unsigned m = 0; m < M; m += mShapePerCTATile)
for (unsigned n = 0; n < N; n += nShapePerCTATile)
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
int mIdx = m / mShapePerCTA * mSizePerThread + mm;
int nIdx = n / nShapePerCTA * nSizePerThread + nn;
int mIdx = m / mShapePerCTATile * mSizePerThread + mm;
int nIdx = n / nShapePerCTATile * nSizePerThread + nn;
int z = isCRow ? mIdx * N / nShapePerCTA * mSizePerThread + nIdx
: nIdx * M / mShapePerCTA * nSizePerThread + mIdx;
int z = isCRow
? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
: nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
hbs[{n + nn, k}], ret[z]);
}

View File

@@ -170,16 +170,17 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
auto dShape = dTensorTy.getShape();
auto aShapePerCTA = triton::gpu::getShapePerCTA(aTensorTy);
auto bShapePerCTA = triton::gpu::getShapePerCTA(bTensorTy);
auto dShapePerCTA = triton::gpu::getShapePerCTA(dTensorTy);
int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
auto repA =
aTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
aTensorTy.getShape(), bitwidth);
aShapePerCTA, bitwidth);
auto repB =
bTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
bTensorTy.getShape(), bitwidth);
bShapePerCTA, bitwidth);
assert(repA[1] == repB[0]);
int repM = repA[0], repN = repB[1], repK = repA[1];

View File

@@ -0,0 +1,391 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "DotOpToLLVM.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
triton::nvgpu::WGMMAEltType getMmaRetType(Value d) {
auto dTy = d.getType().cast<RankedTensorType>().getElementType();
if (dTy.isF32()) {
return triton::nvgpu::WGMMAEltType::f32;
} else if (dTy.isF16()) {
return triton::nvgpu::WGMMAEltType::f16;
} else if (dTy.isInteger(32)) {
return triton::nvgpu::WGMMAEltType::s32;
} else {
llvm::report_fatal_error("Unsupported mma result type found");
}
}
triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
auto aTy = a.getType().cast<RankedTensorType>().getElementType();
if (aTy.isF16()) {
return triton::nvgpu::WGMMAEltType::f16;
} else if (aTy.isBF16()) {
return triton::nvgpu::WGMMAEltType::bf16;
} else if (aTy.isF32() && allowTF32) {
return triton::nvgpu::WGMMAEltType::tf32;
} else if (aTy.isInteger(8)) {
return triton::nvgpu::WGMMAEltType::s8;
} else if (aTy.isFloat8E5M2()) {
return triton::nvgpu::WGMMAEltType::e5m2;
} else if (aTy.isFloat8E4M3FN()) {
return triton::nvgpu::WGMMAEltType::e4m3;
} else {
llvm::report_fatal_error("Unsupported mma operand type found");
}
}
mlir::triton::nvgpu::WGMMADescMode
getModeFromLayout(const SharedEncodingAttr &layout, uint32_t widthInByte) {
int perPhase = layout.getPerPhase();
int maxPhase = layout.getMaxPhase();
uint32_t swizzlingByteWidth = 0;
mlir::triton::nvgpu::WGMMADescMode mode;
if (perPhase == 4 && maxPhase == 2) {
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle32;
swizzlingByteWidth = 32;
} else if (perPhase == 2 && maxPhase == 4) {
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle64;
swizzlingByteWidth = 64;
} else if (perPhase == 1 && maxPhase == 8) {
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle128;
swizzlingByteWidth = 128;
} else {
llvm::report_fatal_error("Unsupported shared layout.");
}
// TODO[biaow]: remove it once we support swizzling size larger than matrix
// width, which requires padding the matrix width to the swizzling size when
// allocating shared memory.
assert(swizzlingByteWidth <= widthInByte &&
"swizzling size larger than matrix width is not supported.");
return mode;
}
class DotOpMmaV3SmemLoader {
public:
DotOpMmaV3SmemLoader(Value tensor, const SharedMemoryObject &smemObj,
SmallVector<int64_t> shape, Value warpId,
unsigned int dimWpt, bool trans,
SmallVector<unsigned int> instrShape,
ConversionPatternRewriter &rewriter, Location loc)
: base(smemObj.base), shape(shape), warpId(warpId), dimWpt(dimWpt),
trans(trans), instrShape(instrShape), rewriter(rewriter), loc(loc) {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
ord = sharedLayout.getOrder();
const int perPhase = sharedLayout.getPerPhase();
const int maxPhase = sharedLayout.getMaxPhase();
elemBytes = tensorTy.getElementTypeBitWidth() / 8;
elemsPerSwizzlingRow = 128 / perPhase / elemBytes;
elemsPerSwizzlingRowVal = i32_val(elemsPerSwizzlingRow);
uint32_t widthInByte = shape[ord[0]] * elemBytes;
mode = getModeFromLayout(sharedLayout, widthInByte);
baseDesc = rewriter.create<triton::nvgpu::WGMMADescCreateOp>(
loc, i64_ty, base, i32_val(shape[ord[1]]), mode);
}
Value smemLoad(int a, int b) {
Value k = i32_val(b * instrShape[1]);
Value m = add(i32_val(a * dimWpt * instrShape[0]),
mul(warpId, i32_val(instrShape[0])));
if (trans) {
std::swap(k, m);
}
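// Decompose the swizzled offset: complete swizzling rows along k give
// leading_offset, the m index gives stride_offset, and the remainder of k
// stays within one swizzling row.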
Value leading_offset = mul(udiv(k, elemsPerSwizzlingRowVal),
i32_val(shape[ord[1]] * elemsPerSwizzlingRow));
Value stride_offset = mul(m, elemsPerSwizzlingRowVal);
Value offset = add(add(leading_offset, stride_offset),
urem(k, elemsPerSwizzlingRowVal));
Value off1 = mul(i32_val(elemBytes), offset);
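// WGMMA shared-memory descriptors encode offsets in units of 16 bytes,
// hence the division of the byte offset by 16.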
Value off_ = zext(i64_ty, udiv(off1, i32_val(16)));
return add(baseDesc, off_);
}
private:
Value base;
SmallVector<int64_t> shape;
Value warpId;
int dimWpt;
bool trans;
Value elemsPerSwizzlingRowVal;
mlir::triton::nvgpu::WGMMADescMode mode;
SmallVector<unsigned int> instrShape;
ArrayRef<unsigned> ord;
ConversionPatternRewriter &rewriter;
Location loc;
int elemsPerSwizzlingRow;
int elemBytes;
Value baseDesc;
};
DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
const MmaEncodingAttr &mmaEncoding, Value tensor,
const SharedMemoryObject &smemObj, Value thread) {
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
auto aSharedLayout = aTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(aSharedLayout && "only support load dot operand from shared.");
auto instrShape = mmaEncoding.getInstrShape();
auto wpt = mmaEncoding.getWarpsPerCTA();
auto aOrd = aSharedLayout.getOrder();
bool transA = aOrd[0] == 0;
auto shapePerCTA = getShapePerCTA(aTensorTy);
int numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0] * wpt[0]);
int numRepK = ceil<unsigned>(shapePerCTA[1], instrShape[2]);
Value warp = udiv(thread, i32_val(32));
Value warpM = urem(warp, i32_val(wpt[0]));
Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
return {tensor,
smemObj,
shapePerCTA,
warpId,
wpt[0],
transA,
{instrShape[0], instrShape[2]},
rewriter,
loc};
}
DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
MmaEncodingAttr &mmaEncoding, Value tensor,
const SharedMemoryObject &smemObj, Value thread) {
auto bTensorTy = tensor.getType().cast<RankedTensorType>();
auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
assert(bSharedLayout && "only support load B from shared.");
auto instrShape = mmaEncoding.getInstrShape();
auto wpt = mmaEncoding.getWarpsPerCTA();
auto bOrd = bSharedLayout.getOrder();
bool transB = bOrd[0] == 1;
auto shapePerCTA = triton::gpu::getShapePerCTA(bTensorTy);
int numRepK = ceil<unsigned>(shapePerCTA[0], instrShape[2]);
int numRepN = ceil<unsigned>(shapePerCTA[1], instrShape[1] * wpt[1]);
Value warp = udiv(thread, i32_val(32));
Value warpMN = udiv(warp, i32_val(wpt[0]));
Value warpN = urem(warpMN, i32_val(wpt[1]));
Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
return {tensor,
smemObj,
shapePerCTA,
warpId,
wpt[1],
transB,
{instrShape[1], instrShape[2]},
rewriter,
loc};
}
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
Operation *op, Value a, Value b, Value c, Value d,
Value loadedA, Value loadedB, Value loadedC,
bool allowTF32, const SharedMemoryObject &smemObjA,
const SharedMemoryObject &smemObjB, bool sync,
Value thread) {
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
auto aSharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto mmaEncoding = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
auto aOrd = aSharedLayout.getOrder();
auto bOrd = bSharedLayout.getOrder();
bool transA = aOrd[0] == 0;
bool transB = bOrd[0] == 1;
auto dShapePerCTA = getShapePerCTA(dTensorTy);
auto instrShape = mmaEncoding.getInstrShape();
auto accSize = 2 * (instrShape[1] / 4);
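// Each thread of the 128-thread warp group holds instrShape[1] / 2
// accumulator values per wgmma tile.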
Type resElemTy = dTensorTy.getElementType();
llvm::SmallVector<Type> elemTypes(accSize, resElemTy);
auto accTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
int M = 4 * instrShape[0];
int N = instrShape[1];
int K = instrShape[2];
auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
int numRepM = ceil<unsigned>(dShapePerCTA[0], shapePerCTATile[0]);
int numRepN = ceil<unsigned>(dShapePerCTA[1], shapePerCTATile[1]);
int numRepK = ceil<unsigned>(aTensorTy.getShape()[1], instrShape[2]);
DotOpMmaV3SmemLoader aLoader =
loadA(typeConverter, rewriter, loc, mmaEncoding, a, smemObjA, thread);
DotOpMmaV3SmemLoader bLoader =
loadB(typeConverter, rewriter, loc, mmaEncoding, b, smemObjB, thread);
auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy);
triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(d);
assert(eltTypeC != triton::nvgpu::WGMMAEltType::f16 &&
"TODO support f16 return type. This requires packing C into "
"vector<2xf16> type.");
triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(a, allowTF32);
triton::nvgpu::WGMMAEltType eltTypeB = eltTypeA;
triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col
: triton::nvgpu::WGMMALayout::row;
triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row
: triton::nvgpu::WGMMALayout::col;
auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
int numTMADescs =
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
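// When no TMA descriptors are used, the operands reached shared memory
// through ordinary stores, so a proxy fence is issued before wgmma reads
// them.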
if (numTMADescs == 0)
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
llvm::SmallVector<Value> mmaOut(accSize);
for (int m = 0; m < numRepM; ++m) {
for (int n = 0; n < numRepN; ++n) {
// reuse the same mmaOut
for (int i = 0; i < accSize; ++i) {
mmaOut[i] = fc[(m * numRepN + n) * accSize + i];
}
Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy);
for (int k = 0; k < numRepK; ++k) {
auto a = aLoader.smemLoad(m, k);
auto b = bLoader.smemLoad(n, k);
ValueRange operands{a, b, d};
d = rewriter.create<triton::nvgpu::WGMMAOp>(loc, accTy, a, b, d, M, N,
K, eltTypeC, eltTypeA,
eltTypeB, layoutA, layoutB);
}
auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy);
for (int i = 0; i < acc.size(); ++i) {
fc[(m * numRepN + n) * accSize + i] = acc[i];
}
}
}
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
if (sync)
rewriter.create<triton::nvgpu::WGMMAWaitOp>(loc, 0);
for (auto &elem : fc) {
elem = bitcast(elem, resElemTy);
}
// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
mmaEncoding.getContext(), SmallVector<Type>(fc.size(), resElemTy));
auto res = typeConverter->packLLElements(loc, fc, rewriter, structTy);
rewriter.replaceOp(op, res);
return success();
}
// Loading $c to registers, returns a Value.
Value loadC(Value tensor, Value llTensor) {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto mmaEncoding = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
assert(mmaEncoding && "Currently, we only support $c with a mma layout.");
auto instrShape = mmaEncoding.getInstrShape();
auto wpt = mmaEncoding.getWarpsPerCTA();
auto shapePerCTA = getShapePerCTA(tensorTy);
auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
int numRepM = ceil<unsigned>(shapePerCTA[0], shapePerCTATile[0]);
int numRepN = ceil<unsigned>(shapePerCTA[1], shapePerCTATile[1]);
size_t fcSize = 2 * (instrShape[1] / 4) * numRepM * numRepN;
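// $c holds instrShape[1] / 2 values per thread for each of the
// numRepM * numRepN tiles.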
auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
assert(structTy.getBody().size() == fcSize &&
"DotOp's $c operand should pass the same number of values as $d in "
"mma layout.");
return llTensor;
}
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Value thread) {
auto loc = op.getLoc();
Value A = op.getA();
Value B = op.getB();
Value C = op.getC();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
assert(ATensorTy.getEncoding().isa<SharedEncodingAttr>() &&
BTensorTy.getEncoding().isa<SharedEncodingAttr>() &&
"Both $a and %b should be Shared layout.");
Value llA, llB, llC;
llA = adaptor.getA();
llB = adaptor.getB();
llC = loadC(C, adaptor.getC());
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
smemObjB, true, thread);
}
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Value thread) {
auto loc = op.getLoc();
Value A = op.getA();
Value B = op.getB();
Value C = op.getC();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
assert(ATensorTy.getEncoding().isa<SharedEncodingAttr>() &&
BTensorTy.getEncoding().isa<SharedEncodingAttr>() &&
"Both $a and %b should be Shared layout.");
Value llA, llB, llC;
llA = adaptor.getA();
llB = adaptor.getB();
llC = loadC(C, adaptor.getC());
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
smemObjB, false, thread);
}

View File

@@ -353,6 +353,13 @@ static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
llvm_unreachable("unimplemented code path");
}
inline Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}
inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
Type srcTy,
ConversionPatternRewriter &rewriter,
@@ -1149,7 +1156,8 @@ struct IndexCastOpLowering
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
@@ -1222,3 +1230,161 @@ void populateElementwiseOpToLLVMPatterns(
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
}
struct FPExtOpConversion
: ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::FPExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isF32() && srcTy.isF16()) {
return false;
}
return true;
}
SmallVector<Value> createDestOps(LLVM::FPExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
return {
FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0][0])};
}
};
struct FPTruncOpConversion
: ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
using Base =
ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::FPTruncOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isF16() && srcTy.isF32()) {
return false;
}
return true;
}
SmallVector<Value> createDestOps(LLVM::FPTruncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
return {
FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0][0])};
}
};
struct TruncOpConversion
: ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::TruncOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(16) && srcTy.isInteger(32)) {
return false;
}
return true;
}
SmallVector<Value> createDestOps(LLVM::TruncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.u16.u32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(operands[0][0], "r");
cvt(res, operand);
return {builder.launch(rewriter, loc, i16_ty, false)};
}
};
struct SExtOpConversion
: ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::SExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
return false;
}
return true;
}
SmallVector<Value> createDestOps(LLVM::SExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.s32.s16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(operands[0][0], "h");
cvt(res, operand);
return {builder.launch(rewriter, loc, i32_ty, false)};
}
};
struct ZExtOpConversion
: ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
static bool isLegalOp(LLVM::ZExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
return false;
}
return true;
}
SmallVector<Value> createDestOps(LLVM::ZExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.u32.u16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(operands[0][0], "h");
cvt(res, operand);
return {builder.launch(rewriter, loc, i32_ty, false)};
}
};
bool isLegalElementwiseOp(Operation *op) {
if (isa<LLVM::FPExtOp>(op)) {
return FPExtOpConversion::isLegalOp(cast<LLVM::FPExtOp>(op));
} else if (isa<LLVM::FPTruncOp>(op)) {
return FPTruncOpConversion::isLegalOp(cast<LLVM::FPTruncOp>(op));
} else if (isa<LLVM::TruncOp>(op)) {
return TruncOpConversion::isLegalOp(cast<LLVM::TruncOp>(op));
} else if (isa<LLVM::SExtOp>(op)) {
return SExtOpConversion::isLegalOp(cast<LLVM::SExtOp>(op));
} else if (isa<LLVM::ZExtOp>(op)) {
return ZExtOpConversion::isLegalOp(cast<LLVM::ZExtOp>(op));
}
return true;
}
void populateElementwiseOpToPTXPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<FPExtOpConversion>(typeConverter, benefit);
patterns.add<FPTruncOpConversion>(typeConverter, benefit);
patterns.add<TruncOpConversion>(typeConverter, benefit);
patterns.add<SExtOpConversion>(typeConverter, benefit);
patterns.add<ZExtOpConversion>(typeConverter, benefit);
}

View File

@@ -8,8 +8,13 @@ using namespace mlir::triton;
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit);
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit);
bool isLegalElementwiseOp(Operation *op);
void populateElementwiseOpToPTXPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit);
#endif

View File

@@ -4,10 +4,15 @@
#include "ConvertLayoutOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -64,6 +69,9 @@ struct LoadOpConversion
Value other = op.getOther();
// adaptor values
assert(!isTensorPointerType(ptr.getType()) &&
"Cannot convert load with a tensor pointer into LLVM; "
"this case should be transformed to normal load before lowering");
Value llPtr = adaptor.getPtr();
Value llMask = adaptor.getMask();
Value llOther = adaptor.getOther();
@@ -378,6 +386,251 @@ struct StoreOpConversion
return success();
}
};
// TODO: refactor to share common logic with InsertSliceAsyncV2OpConversion
struct StoreAsyncOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::StoreAsyncOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::StoreAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
StoreAsyncOpConversion(TritonGPUToLLVMTypeConverter &converter,
ModuleAllocation &allocation,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
const TensorPtrMapT *tensorPtrMap,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::StoreAsyncOp>(
converter, allocation, tmaMetadata, benefit),
tensorPtrMap(tensorPtrMap) {}
LogicalResult
matchAndRewrite(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
auto dst = op.getDst();
auto src = op.getSrc();
auto srcTy = src.getType().cast<RankedTensorType>();
auto elemTy = srcTy.getElementType();
auto rank = srcTy.getRank();
assert(rank > 0 && rank <= 5);
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp");
auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp");
int numTMADescs = getNumTMADescs(llFuncOp);
assert(numTMADescs > 0);
auto sharedLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(sharedLayout && "expected shared encoding");
mlir::triton::gpu::TMAInfo tmaInfo;
tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy);
tmaInfo.tensorRank = rank;
assert(tmaMetadata);
auto inOrder = sharedLayout.getOrder();
unsigned TMADescIdx = tmaMetadata->size();
unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
auto dstOrder = makeTensorPtr.getOrder();
unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;
auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
auto it = std::find(order.begin(), order.end(), i);
assert(it != order.end());
return std::distance(order.begin(), it);
};
std::vector<int32_t> globalDimsArgIdx;
std::vector<int32_t> globalStridesArgIdx;
// constant values are mapped to (-1 - value)
for (int i = 0; i < rank; ++i) {
int32_t argIdx = -1;
auto dim = getDimOfOrder(dstOrder, i);
argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
globalDimsArgIdx.emplace_back(argIdx);
// handle constant stride
argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
globalStridesArgIdx.emplace_back(argIdx);
}
tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
tmaInfo.globalStridesArgIdx = globalStridesArgIdx;
std::vector<uint32_t> boxDims;
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
auto tensorShape = makeTensorPtr.getResult()
.getType()
.cast<triton::PointerType>()
.getPointeeType()
.cast<RankedTensorType>()
.getShape();
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
// The innermost box dimension is capped at 128 bytes; wider tiles are
// split into multiple boxes.
uint32_t bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8;
uint32_t numBox{1};
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(dstOrder, i);
auto tNumElems = shapePerCTA[dim];
if (i == 0 && tNumElems * bytesPerElem > 128) {
tNumElems = 128 / bytesPerElem;
numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
}
boxDims.emplace_back(tNumElems);
}
std::vector<uint32_t> elementStrides(rank, 1);
tmaInfo.boxDims = boxDims;
tmaInfo.elementStrides = elementStrides;
CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
assert(
((elemTy.getIntOrFloatBitWidth() == 16 && sharedLayout.getVec() == 8) or
(elemTy.getIntOrFloatBitWidth() == 32 &&
sharedLayout.getVec() == 4)) &&
"Unexpected shared layout for StoreAsyncOp");
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
else
llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp");
tmaInfo.swizzle = swizzle;
tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
tmaInfo.l2Promotion =
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
tmaInfo.oobFill =
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
tmaMetadata->emplace_back(tmaInfo);
Value llDst = adaptor.getDst();
Value llSrc = adaptor.getSrc();
auto srcShape = srcTy.getShape();
auto smemObj = getSharedMemoryObjectFromStruct(loc, llSrc, rewriter);
SmallVector<Value> offsetVals;
for (auto i = 0; i < srcShape.size(); ++i) {
offsetVals.emplace_back(i32_val(0));
}
Value tmaDesc =
llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
typeConverter->convertType(rewriter.getI8Type()), 3);
auto threadId = getThreadId(rewriter, loc);
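// Only thread 0 of the CTA issues the TMA store; the predicate disables
// all other threads.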
Value pred = icmp_eq(threadId, i32_val(0));
auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
dst.getType());
uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1,
std::multiplies<uint32_t>());
Value clusterCTAId = getClusterCTAId(rewriter, loc);
SmallVector<Value> multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
for (uint32_t b = 0; b < numBox; ++b) {
SmallVector<Value> coord;
// raw coord
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(dstOrder, i);
coord.push_back(llCoord[dim]);
}
// coord with box and cta offset
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(dstOrder, i);
if (i == 0) {
coord[i] = add(coord[i], i32_val(b * boxDims[i]));
auto CTAOffset =
mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i]));
coord[i] = add(coord[i], CTAOffset);
} else {
coord[i] = add(coord[i],
mul(multiDimClusterCTAId[dim], i32_val(boxDims[i])));
}
}
Value srcOffset = i32_val(b * boxStride);
auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
Value srcPtrBase = gep(srcPtrTy, smemObj.base, srcOffset);
auto addr = bitcast(srcPtrBase, ptrI8SharedTy);
rewriter.create<triton::nvgpu::TMAStoreTiledOp>(loc, tmaDesc, addr, pred,
coord);
}
rewriter.eraseOp(op);
return success();
}
private:
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
if (ty.isF16()) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if (ty.isF32()) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else {
llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp");
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
}
}
unsigned getArgIdx(Value v) const {
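// Constant operands are encoded as (-1 - value), matching the convention
// noted at the call site.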
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
return -1 -
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
}
if (v.getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else if (v.getParentBlock()->isEntryBlock() && v.isa<BlockArgument>()) {
// In the entry block and a BlockArgument: function arguments are the
// arguments of the entry block (bb0) in MLIR.
return v.cast<BlockArgument>().getArgNumber();
} else if (v.getParentBlock()->isEntryBlock() &&
(!v.isa<BlockArgument>())) {
// in entryblock but not BlockArgument
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else if (!v.getParentBlock()->isEntryBlock()) {
// in non-entryblock
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else {
llvm::report_fatal_error(
"Operand of `MakeTensorPtrOp` is not the function's argument");
return 0;
}
}
int getNumTMADescs(LLVM::LLVMFuncOp func) const {
if (!func->hasAttr(kAttrNumTMALoadDescsName)) {
llvm::report_fatal_error("TritonGPU module should contain a "
"triton_gpu.num-tma-load attribute");
return -1;
}
if (!func->hasAttr(kAttrNumTMAStoreDescsName)) {
llvm::report_fatal_error("TritonGPU module should contain a "
"triton_gpu.num-tma-store attribute");
return -1;
}
return func->getAttr(kAttrNumTMAStoreDescsName)
.cast<IntegerAttr>()
.getInt() +
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
}
const TensorPtrMapT *tensorPtrMap;
};
struct AtomicCASOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
@@ -854,11 +1107,386 @@ struct InsertSliceAsyncOpConversion
}
};
struct InsertSliceAsyncV2OpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::InsertSliceAsyncV2Op> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::InsertSliceAsyncV2Op>::
ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncV2OpConversion(TritonGPUToLLVMTypeConverter &converter,
ModuleAllocation &allocation,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
const TensorPtrMapT *tensorPtrMap,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::InsertSliceAsyncV2Op>(converter, allocation,
tmaMetadata, benefit),
tensorPtrMap(tensorPtrMap) {}
LogicalResult
matchAndRewrite(triton::nvidia_gpu::InsertSliceAsyncV2Op op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getResult().getType().cast<RankedTensorType>();
auto elemTy = resultTy.getElementType();
auto rank = resultTy.getRank() - 1;
// TODO: support any valid rank in (3, 4, 5)
assert(rank > 0 && rank <= 5);
SmallVector<unsigned> shape;
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for InsertSliceAsyncV2Op");
auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
assert(llFuncOp && "LLVMFuncOp not found for InsertSliceAsyncV2Op");
int numTMADescs = getNumTMADescs(llFuncOp);
assert(numTMADescs > 0);
auto sharedLayout = resultTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(sharedLayout && "unexpected layout of InsertSliceAsyncV2Op");
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
mlir::triton::gpu::TMAInfo tmaInfo;
tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy);
tmaInfo.tensorRank = rank;
assert(tmaMetadata);
unsigned TMADescIdx = tmaMetadata->size();
unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
auto inOrder = makeTensorPtr.getOrder();
unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;
auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
auto it = std::find(order.begin(), order.end(), i);
assert(it != order.end());
return std::distance(order.begin(), it);
};
std::vector<int32_t> globalDimsArgIdx;
std::vector<int32_t> globalStridesArgIdx;
// constant values are mapped to (-1 - value)
for (int i = 0; i < rank; ++i) {
int32_t argIdx = -1;
auto dim = getDimOfOrder(inOrder, i);
argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
globalDimsArgIdx.emplace_back(argIdx);
// handle constant stride
argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
globalStridesArgIdx.emplace_back(argIdx);
}
tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
tmaInfo.globalStridesArgIdx = globalStridesArgIdx;
std::vector<uint32_t> boxDims;
auto tensorShape = makeTensorPtr.getResult()
.getType()
.cast<triton::PointerType>()
.getPointeeType()
.cast<RankedTensorType>()
.getShape();
SmallVector<unsigned> numMcast(rank);
unsigned accNumMcast = 1;
for (unsigned i = 0; i < rank; ++i) {
numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
accNumMcast *= numMcast[i];
}
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
for (size_t i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(inOrder, i);
// In the case of TMA multicast, always slice along the highest-order
// dimension.
if (i == rank - 1) {
assert(shapePerCTA[dim] >= accNumMcast &&
"cases when the size of the highest order is smaller "
"than numMcasts is not implemented");
boxDims.emplace_back(shapePerCTA[dim] / accNumMcast);
} else {
boxDims.emplace_back(shapePerCTA[dim]);
}
}
std::vector<uint32_t> elementStrides(rank, 1);
tmaInfo.elementStrides = elementStrides;
CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
else
llvm::report_fatal_error(
"Unsupported shared layout for InsertSliceAsyncV2Op");
tmaInfo.swizzle = swizzle;
tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
tmaInfo.l2Promotion =
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
tmaInfo.oobFill =
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
uint32_t numBoxes = 1;
uint32_t elemSizeOfBytes = elemTy.getIntOrFloatBitWidth() / 8;
if (swizzle == CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
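// With 128B swizzling the innermost box dimension may not exceed 128
// bytes; halve it and issue multiple boxes instead.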
while (elemSizeOfBytes * boxDims[0] > 128) {
boxDims[0] = boxDims[0] / 2;
numBoxes *= 2;
}
}
tmaInfo.boxDims = boxDims;
tmaMetadata->emplace_back(tmaInfo);
uint32_t elemsPerBox =
std::accumulate(boxDims.begin(), boxDims.end(), 1, std::multiplies{});
Value clusterCTAId = getClusterCTAId(rewriter, loc);
SmallVector<Value> multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
Value llDst = adaptor.getDst();
Value llIndex = adaptor.getIndex();
Value src = op.getSrc();
Value dst = op.getDst();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
// The offset applied to coord to account for multicast slicing.
SmallVector<Value> mcastOffsetVals;
// The index of the slice this CTA is responsible for.
SmallVector<Value> multiDimSliceIdx(rank);
for (auto i = 0; i < rank; ++i)
multiDimSliceIdx[i] =
udiv(multiDimClusterCTAId[i], i32_val(CTASplitNum[i]));
Value sliceIdx =
linearize(rewriter, loc, multiDimSliceIdx, numMcast, CTAOrder);
Value sliceCoord;
for (auto i = 0; i < rank; ++i) {
if (inOrder[i] == rank - 1) {
// TODO[goostavz]: The case where the highest-order dimension is smaller
// than the number of multicast slices is not implemented.
sliceCoord = mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast));
mcastOffsetVals.emplace_back(
mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast)));
} else {
mcastOffsetVals.emplace_back(i32_val(0));
}
}
uint32_t elemsPerSlice = std::accumulate(
shapePerCTA.begin(), shapePerCTA.end(), 1, std::multiplies{});
Value dstOffsetCommon = mul(llIndex, i32_val(elemsPerSlice));
// [benzh] sliceCoord should accumulate the multipliers of the higher
// dimensions; currently only rank == 2 is supported.
dstOffsetCommon =
add(dstOffsetCommon, mul(sliceCoord, i32_val(boxDims[0])));
auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
Value tmaDesc =
llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
// TODO: sink this logic into Triton::NVGPU dialect and support more
// cache-policy modes
Value l2Desc = int_val(64, 0x1000000000000000ll);
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
typeConverter->convertType(rewriter.getI8Type()), 3);
SmallVector<Value> coordCommon;
auto llCoord = getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, src.getType());
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(inOrder, i);
Value coordDim = bitcast(llCoord[dim], i32_ty);
if (CTASplitNum[dim] != 1) {
// Add offset for each CTA
// boxDims[i] * (multiDimClusterCTAId[i] % CTASplitNum[i]);
auto CTAOffset =
mul(i32_val(shapePerCTA[dim]),
urem(multiDimClusterCTAId[dim], i32_val(CTASplitNum[dim])));
coordDim = add(coordDim, CTAOffset);
}
if (i == rank - 1)
// Add offset in case of multicast slicing
coordCommon.push_back(add(coordDim, mcastOffsetVals[dim]));
else
coordCommon.push_back(coordDim);
}
auto threadId = getThreadId(rewriter, loc);
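// Only thread 0 issues the TMA load; the mbarrier and multicast mask
// take care of completion tracking and distribution across CTAs.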
Value pred = icmp_eq(threadId, i32_val(0));
auto mask = adaptor.getMask();
if (mask) {
// TODO(thomas): What is the right implementation for this case?
assert(mask.getType().isInteger(1) &&
"need to implement cases with tensor mask");
pred = rewriter.create<arith::AndIOp>(loc, pred, mask);
}
Value mcastMask = getMCastMask(sharedLayout, rewriter, loc, clusterCTAId);
for (size_t i = 0; i < numBoxes; ++i) {
Value dstOffset =
add(dstOffsetCommon, i32_val(i * elemsPerBox * accNumMcast));
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
SmallVector<Value> coord = coordCommon;
coord[0] = add(coordCommon[0], i32_val(i * boxDims[0]));
rewriter.create<triton::nvgpu::TMALoadTiledOp>(
loc, bitcast(dstPtrBase, ptrI8SharedTy), adaptor.getMbar(), tmaDesc,
l2Desc, pred, coord, mcastMask);
}
rewriter.replaceOp(op, llDst);
return success();
}
private:
Value getMCastMask(const SharedEncodingAttr &sharedLayout,
ConversionPatternRewriter &rewriter, Location loc,
Value clusterCTAId) const {
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
// Short path when no multicast is needed
if (CTAsPerCGA == CTASplitNum)
return nullptr;
// Short path when bcastMask is a constant
bool isConstMcastMask = true;
for (unsigned s : CTASplitNum) {
if (s > 1) {
isConstMcastMask = false;
break;
}
}
if (isConstMcastMask) {
unsigned numCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(),
1, std::multiplies{});
return int_val(/*width*/ 16, (1u << numCTAs) - 1);
}
SmallVector<Value> multiDimCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
auto rank = CTAOrder.size();
SmallVector<SmallVector<Value>> multiDimMask(rank);
unsigned accNumMcast = 1;
SmallVector<unsigned> numMcast(rank);
for (unsigned i = 0; i < rank; ++i) {
// For the ith dimension, CTAsPerCGA[i]/CTASplitNum[i] values are to be
// broadcast, which for this CTAId are:
// multiDimCTAId[i] % CTASplitNum[i] + (0 ..
// (CTAsPerCGA[i]/CTASplitNum[i] - 1)) * CTASplitNum[i]
// TODO: can CTAsPerCGA[i]/CTASplitNum[i] ever be < 1?
Value rem = urem(multiDimCTAId[i], i32_val(CTASplitNum[i]));
numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
accNumMcast *= numMcast[i];
for (unsigned j = 0; j < numMcast[i]; ++j) {
if (j == 0) {
multiDimMask[i].push_back(rem);
} else {
multiDimMask[i].push_back(add(rem, i32_val(j * CTASplitNum[i])));
}
}
}
Value bcastMask = int_val(/*width*/ 16, 0);
Value _1_i16 = int_val(/*width*/ 16, 1);
for (unsigned i = 0; i < accNumMcast; ++i) {
SmallVector<unsigned> multiDimIdx =
getMultiDimIndex<unsigned>(i, numMcast, CTAOrder);
SmallVector<Value> multiDimMaskedCTAId(rank);
for (unsigned dim = 0; dim < rank; ++dim) {
multiDimMaskedCTAId[dim] = multiDimMask[dim][multiDimIdx[dim]];
}
Value bcastCTAId =
linearize(rewriter, loc, multiDimMaskedCTAId, CTAsPerCGA, CTAOrder);
// bcastMask |= 1u << bcastCTAId;
bcastMask = or_(bcastMask, shl(_1_i16, trunc(i16_ty, bcastCTAId)));
}
return bcastMask;
}
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
if (ty.isF16()) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if (ty.isF32()) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else {
llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
}
}
unsigned getArgIdx(Value v) const {
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
return -1 -
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
}
if (v.getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else if (v.getParentBlock()->isEntryBlock() && v.isa<BlockArgument>()) {
// in entryblock and is a BlockArgument; function arguments are the
// arguments of entryblock bb0 in MLIR
return v.cast<BlockArgument>().getArgNumber();
} else if (v.getParentBlock()->isEntryBlock() &&
(!v.isa<BlockArgument>())) {
// in entryblock but not BlockArgument
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else if (!v.getParentBlock()->isEntryBlock()) {
// in non-entryblock
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else {
llvm::report_fatal_error(
"Operand of `MakeTensorPtrOp` is not the function's argument");
return 0;
}
}
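A tiny illustration of the index encoding used above: non-negative values are function-argument positions, while integer constants are folded into the wrapped-around "negative" range via -1 - value. The helper names below are hypothetical.

#include <cassert>

unsigned encodeArgument(unsigned argNumber) { return argNumber; }
unsigned encodeConstant(unsigned constValue) { return -1 - constValue; }

int main() {
  assert(encodeArgument(3) == 3u);          // 4th function argument
  assert(encodeConstant(0) == 0xFFFFFFFFu); // constant 0 encodes as -1
  assert(encodeConstant(5) == 0xFFFFFFFAu); // constant 5 encodes as -6
  return 0;
}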
int getNumTMADescs(LLVM::LLVMFuncOp func) const {
if (!func->hasAttr(kAttrNumTMALoadDescsName)) {
llvm::report_fatal_error("TritonGPU module should contain a "
"triton_gpu.num-tma-load attribute");
return -1;
}
if (!func->hasAttr(kAttrNumTMAStoreDescsName)) {
llvm::report_fatal_error("TritonGPU module should contain a "
"triton_gpu.num-tma-store attribute");
return -1;
}
return func->getAttr(kAttrNumTMAStoreDescsName)
.cast<IntegerAttr>()
.getInt() +
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
}
const TensorPtrMapT *tensorPtrMap;
};
void populateLoadStoreOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit) {
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
@@ -869,4 +1497,8 @@ void populateLoadStoreOpToLLVMPatterns(
indexCacheInfo, benefit);
patterns.add<InsertSliceAsyncOpConversion>(
typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
patterns.add<InsertSliceAsyncV2OpConversion>(
typeConverter, allocation, tmaMetadata, tensorPtrMap, benefit);
patterns.add<StoreAsyncOpConversion>(typeConverter, allocation, tmaMetadata,
tensorPtrMap, benefit);
}

View File

@@ -8,8 +8,10 @@ using namespace mlir::triton;
void populateLoadStoreOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit);
#endif

View File

@@ -1,5 +1,7 @@
#include "ReduceOpToLLVM.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
using namespace mlir;
using namespace mlir::triton;
@@ -14,8 +16,13 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
struct ReduceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern;
ReduceOpConversion(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
int computeCapability, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp>(
typeConverter, allocation, indexCacheInfo, benefit),
computeCapability(computeCapability) {}
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
@@ -26,14 +33,12 @@ public:
}
private:
int computeCapability;
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
llvm::SmallVectorImpl<Value> &acc, ValueRange cur,
bool isFirst) const {
SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
if (isFirst) {
acc.resize(cur.size());
for (unsigned i = 0; i < cur.size(); ++i) {
acc[i] = cur[i];
}
acc = SmallVector<Value>(cur.begin(), cur.end());
return;
}
@@ -114,7 +119,7 @@ private:
// writeIdx[originalAxis] = index[originalAxis] / axisSizePerThread
writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (!mmaLayout.isAmpere()) {
if (!mmaLayout.isAmpere() && !mmaLayout.isHopper()) {
llvm::report_fatal_error("Unsupported layout");
}
if (originalAxis == 0) {
@@ -157,7 +162,6 @@ private:
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
}
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
auto smemShape = helper.getScratchConfigBasic();
unsigned elems = product<unsigned>(smemShape);
@@ -171,33 +175,10 @@ private:
elemPtrTys[i]);
}
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
// Emits indices of the original tensor that each thread
// would own
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
// Emits offsets (the offset from the base index)
// of the original tensor that each thread would own
// NOTE: Assumes offsets don't actually depend on type
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTys[0]);
// Keep track of accumulations and their indices
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
Region *combineOp = &op.getCombineOp();
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
if (isFirst)
indices[key] = srcIndices[i];
}
reduceWithinThreads(helper, srcValues, accs, indices, rewriter);
// cached int32 constants
std::map<int, Value> ints;
@@ -249,15 +230,17 @@ private:
readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset);
}
barrier();
sync(rewriter, loc, op);
// Combine accumulator value from another thread
SmallVector<Value> cur(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
cur[i] = load(readPtrs[i]);
}
accumulate(rewriter, *combineOp, acc, cur, false);
accumulate(rewriter, op.getCombineOp(), acc, cur, false);
sync(rewriter, loc, op);
barrier();
// Publish our new accumulator value to shared memory
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
store(acc[i], writePtrs[i]);
@@ -265,7 +248,7 @@ private:
}
}
barrier();
sync(rewriter, loc, op);
// set output values
SmallVector<Value> results(op.getNumOperands());
@@ -302,78 +285,186 @@ private:
return success();
}
// Use warp shuffle for reduction within warps and shared memory for data
// exchange across warps
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ReduceOpHelper helper(op);
Location loc = op->getLoc();
unsigned axis = adaptor.getAxis();
auto srcTys = op.getInputTypes();
auto srcLayout = helper.getSrcLayout();
if (!helper.isSupportedLayout()) {
assert(false && "Unexpected srcLayout in ReduceOpConversion");
void sync(ConversionPatternRewriter &rewriter, Location loc,
triton::ReduceOp op) const {
// TODO[shuhaoj]: avoid hard-coding numThreads. Hide the async_agent
// attr.
if (op->hasAttr("async_agent")) {
barSync(rewriter, op, getAgentIds(op).front(), 128);
} else {
barrier();
}
auto srcOrd = triton::gpu::getOrder(srcLayout);
auto srcShape = helper.getSrcShape();
}
SmallVector<Type> elemPtrTys(srcTys.size());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto ty = srcTys[i].getElementType();
auto llvmElemTy = getTypeConverter()->convertType(ty);
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
// Check if the reduction can use a redux op and return the kind.
std::optional<NVVM::ReduxKind> matchReduxKind(triton::ReduceOp op) const {
if (computeCapability < 80)
return std::nullopt;
if (op.getNumOperands() != 1 || op.getNumResults() != 1)
return std::nullopt;
Block *block = &(*op.getCombineOp().begin());
Operation *yield = block->getTerminator();
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
reduceOp->getNumResults() != 1 ||
!reduceOp->getResultTypes()[0].isInteger(32))
return std::nullopt;
if (reduceOp->getOperand(0) != block->getArgument(0) ||
reduceOp->getOperand(1) != block->getArgument(1))
return std::nullopt;
if (isa<arith::AddIOp>(reduceOp))
return NVVM::ReduxKind::ADD;
if (isa<arith::AndIOp>(reduceOp))
return NVVM::ReduxKind::AND;
if (isa<arith::OrIOp>(reduceOp))
return NVVM::ReduxKind::OR;
if (isa<arith::XOrIOp>(reduceOp))
return NVVM::ReduxKind::XOR;
if (auto externalCall =
dyn_cast<triton::PureExternElementwiseOp>(reduceOp)) {
if (externalCall.getSymbol() == "__nv_min")
return NVVM::ReduxKind::MIN;
if (externalCall.getSymbol() == "__nv_umin")
return NVVM::ReduxKind::UMIN;
if (externalCall.getSymbol() == "__nv_max")
return NVVM::ReduxKind::MAX;
if (externalCall.getSymbol() == "__nv_umax")
return NVVM::ReduxKind::UMAX;
}
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
SmallVector<Value> smemBases(op.getNumOperands());
bool isWarpSync = helper.isWarpSynchronous();
if (!isWarpSync) {
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
smemBases[i] =
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
elemPtrTys[i]);
}
}
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
return std::nullopt;
}
// Reduce along op axis for elements that are in the same thread. The
// accumulated value is stored in accs.
void reduceWithinThreads(
ReduceOpHelper &helper, SmallVector<SmallVector<Value>> &srcValues,
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
RankedTensorType operandType = op.getInputTypes()[0];
// Assumes offsets don't actually depend on type
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTys[0]);
emitOffsetForLayout(helper.getSrcLayout(), operandType);
unsigned srcElems = getTotalElemsPerThread(operandType);
auto *combineOp = &op.getCombineOp();
auto srcIndices =
emitIndices(op.getLoc(), rewriter, helper.getSrcLayout(), operandType);
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
key[op.getAxis()] = 0;
bool isFirst = accs.find(key) == accs.end();
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
if (isFirst)
indices[key] = srcIndices[i];
}
}
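A host-side sketch of the per-thread reduction above, with made-up offsets and addition standing in for the combine region: elements whose coordinates coincide after zeroing the reduced axis are folded into one accumulator.

#include <cstdio>
#include <map>
#include <vector>

int main() {
  unsigned axis = 1; // reduce along the second dimension
  // Per-thread (offset, value) pairs, e.g. a 2x2 tile owned by one thread.
  std::vector<std::pair<std::vector<unsigned>, int>> elems = {
      {{0, 0}, 1}, {{0, 1}, 2}, {{1, 0}, 3}, {{1, 1}, 4}};
  std::map<std::vector<unsigned>, int> accs;
  for (auto &[offset, value] : elems) {
    std::vector<unsigned> key = offset;
    key[axis] = 0;                      // collapse the reduced axis
    bool isFirst = accs.find(key) == accs.end();
    accs[key] = isFirst ? value : accs[key] + value; // "accumulate" = add
  }
  for (auto &[key, acc] : accs)
    std::printf("row %u -> %d\n", key[0], acc); // row 0 -> 3, row 1 -> 7
  return 0;
}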
// Apply a warp reduction across the given number of contiguous lanes, using
// the op region and the accumulator values as the source.
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce) const {
if (auto kind = matchReduxKind(op)) {
// Based on benchmarking on A100, the redux op gives a speedup only when
// doing a single reduction (not partitioned) and when the mask is static.
// Therefore we currently only enable it to reduce across all the lanes.
if (numLaneToReduce == 32) {
assert(acc.size() == 1);
Value mask = i32_val(0xFFFFFFFF);
// Even though we currently don't use redux for partitioned reductions,
// the code below supports them in case we want to tweak the heuristic.
if (numLaneToReduce < 32) {
// For a partitioned reduction we need to calculate the mask so that
// each group of numLaneToReduce threads has the correct mask.
unsigned bitmask = (1 << numLaneToReduce) - 1;
Value threadId = getThreadId(rewriter, loc);
Value laneId = urem(threadId, i32_val(32));
mask = shl(i32_val(bitmask),
and_(laneId, i32_val(~(numLaneToReduce - 1))));
}
acc[0] = rewriter.create<NVVM::ReduxOp>(loc, acc[0].getType(), acc[0],
*kind, mask);
return;
}
}
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
}
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
}
}
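A host-side simulation of the shuffle loop above, assuming shflSync lowers to a butterfly (XOR-lane) shuffle; after log2(numLaneToReduce) rounds every participating lane holds the full reduction. Addition stands in for the combine region.

#include <cstdio>
#include <vector>

int main() {
  unsigned numLaneToReduce = 8;            // must be a power of two
  std::vector<int> lanes = {1, 2, 3, 4, 5, 6, 7, 8};
  for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
    std::vector<int> shfl(lanes.size());
    for (unsigned lane = 0; lane < lanes.size(); ++lane)
      shfl[lane] = lanes[lane ^ N];        // value from the partner lane
    for (unsigned lane = 0; lane < lanes.size(); ++lane)
      lanes[lane] += shfl[lane];           // "accumulate" with addition
  }
  std::printf("%d\n", lanes[0]);           // 36, and every lane agrees
  return 0;
}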
// Reduce across threads within each warp.
void
reduceWithinWarps(ReduceOpHelper &helper,
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = accs[key];
warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps);
}
}
// Pack the accumulator values and replace the reduce op with the result.
void packResults(ReduceOpHelper &helper,
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
Location loc = op.getLoc();
unsigned axis = op.getAxis();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
unsigned resultElems = getTotalElemsPerThread(resultTy);
SmallVector<SmallVector<unsigned>> resultOffset =
emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> resultVals;
for (int j = 0; j < resultElems; j++) {
auto key = resultOffset[j];
key.insert(key.begin() + axis, 0);
resultVals.push_back(accs[key][i]);
}
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
} else
results[i] = accs.begin()->second[i];
}
rewriter.replaceOp(op, results);
}
// Return the type of the shared memory pointer for operand i.
Type getElementPtrType(triton::ReduceOp op, int i) const {
auto ty = op.getInputTypes()[i].getElementType();
auto llvmElemTy = getTypeConverter()->convertType(ty);
return LLVM::LLVMPointerType::get(llvmElemTy, 3);
}
void storeWarpReduceToSharedMemory(
ReduceOpHelper &helper,
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
SmallVector<Value> &smemBases,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
Location loc = op.getLoc();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);
auto srcLayout = helper.getSrcLayout();
auto srcShape = helper.getSrcShape();
unsigned axis = op.getAxis();
auto smemShapes = helper.getScratchConfigsFast();
auto threadsPerWarp =
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
@@ -391,67 +482,38 @@ private:
Value zero = i32_val(0);
Value laneZero = icmp_eq(laneIdAxis, zero);
std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> acc = it.second;
// Reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
}
accumulate(rewriter, *combineOp, acc, shfl, false);
}
if (isWarpSync) {
finalAccs[key] = acc;
continue;
}
SmallVector<Value> &acc = it.second;
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset);
auto elemPtrTy = getElementPtrType(op, i);
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
storeShared(rewriter, loc, writePtr, acc[i], laneZero);
}
}
}
if (isWarpSync) {
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
unsigned resultElems = getTotalElemsPerThread(resultTy);
SmallVector<SmallVector<unsigned>> resultOffset =
emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> resultVals;
for (int j = 0; j < resultElems; j++) {
auto key = resultOffset[j];
key.insert(key.begin() + axis, 0);
resultVals.push_back(finalAccs[key][i]);
}
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
} else
results[i] = finalAccs.begin()->second[i];
}
rewriter.replaceOp(op, results);
return success();
}
// Load each warp's partial reduction, accumulate them into a final value,
// and store it back to shared memory.
void accumulatePartialReductions(ReduceOpHelper &helper,
SmallVector<Value> &smemBases,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
auto srcLayout = helper.getSrcLayout();
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
Location loc = op.getLoc();
barrier();
// The second round of shuffle reduction
// now the problem size: sizeInterWarps, s1, s2, .. , sn
// where sizeInterWarps is 2^m
//
// Each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value zero = i32_val(0);
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numThreads =
@@ -464,23 +526,18 @@ private:
// i32_val(sizeInterWarps))
SmallVector<Value> acc(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
auto elemPtrTy = getElementPtrType(op, i);
Value readPtr = gep(elemPtrTy, smemBases[i], readOffset);
acc[i] = load(readPtr);
}
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
}
accumulate(rewriter, *combineOp, acc, shfl, false);
}
warpReduce(rewriter, loc, acc, op, sizeInterWarps);
// only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
SmallVector<Value> writePtrs(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
auto elemPtrTy = getElementPtrType(op, i);
writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset);
}
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
@@ -496,10 +553,17 @@ private:
readOffset = add(readOffset, i32_val(numThreads));
}
}
}
barrier();
// set output values
// Load the final reduction from shared memory and replace the reduce result
// with it.
void loadReductionAndPackResult(ReduceOpHelper &helper,
SmallVector<Value> &smemBases,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
Location loc = op.getLoc();
auto smemShapes = helper.getScratchConfigsFast();
auto order = getOrder(helper.getSrcLayout());
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
@@ -513,10 +577,11 @@ private:
SmallVector<Value> resultVals(resultElems);
for (size_t j = 0; j < resultElems; ++j) {
SmallVector<Value> readIdx = resultIndices[j];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShapes[0], order);
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
Value readPtr =
gep(getElementPtrType(op, i), smemBases[i], readOffset);
resultVals[j] = load(readPtr);
}
@@ -528,6 +593,65 @@ private:
}
}
rewriter.replaceOp(op, results);
}
// Use warp shuffle for reduction within warps and shared memory for data
// exchange across warps
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ReduceOpHelper helper(op);
assert(helper.isSupportedLayout() &&
"Unexpected srcLayout in ReduceOpConversion");
Location loc = op->getLoc();
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
// First reduce all the values along axis within each thread.
reduceWithinThreads(helper, srcValues, accs, indices, rewriter);
// Then reduce across threads within a warp.
reduceWithinWarps(helper, accs, rewriter);
if (helper.isWarpSynchronous()) {
// If all the values to be reduced are within the same warp there is
// nothing left to do.
packResults(helper, accs, rewriter);
return success();
}
// Compute a shared memory base per operand.
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
SmallVector<Value> smemBases(op.getNumOperands());
smemBases[0] =
bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()),
getElementPtrType(op, 0));
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
smemBases[i] = bitcast(gep(getElementPtrType(op, i - 1), smemBases[i - 1],
i32_val(maxElems)),
getElementPtrType(op, i));
}
storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter);
sync(rewriter, loc, op);
// The second round of shuffle reduction
// now the problem size: sizeInterWarps, s1, s2, .. , sn
// where sizeInterWarps is 2^m
//
// Each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
accumulatePartialReductions(helper, smemBases, rewriter);
// We could avoid this barrier for some layouts, however this is not the
// general case.
// TODO: optimize away the barrier for the layouts that allow it.
sync(rewriter, loc, op);
// set output values
loadReductionAndPackResult(helper, smemBases, rewriter);
return success();
}
@@ -535,9 +659,10 @@ private:
void populateReduceOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
int computeCapability, PatternBenefit benefit) {
patterns.add<ReduceOpConversion>(typeConverter, allocation, indexCacheInfo,
benefit);
computeCapability, benefit);
}

View File

@@ -8,8 +8,9 @@ using namespace mlir::triton;
void populateReduceOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
int computeCapability, PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,43 @@
#include "RegReallocOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
struct RegAllocOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegAllocOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::RegAllocOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
rewriter.replaceOpWithNewOp<triton::nvgpu::RegAllocOp>(
op, adaptor.getRegCount());
return success();
}
};
struct RegDeallocOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegDeallocOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::RegDeallocOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
rewriter.replaceOpWithNewOp<triton::nvgpu::RegDeallocOp>(
op, adaptor.getRegCount());
return success();
}
};
void populateRegReallocOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
const ModuleAllocation &allocation, PatternBenefit benefit) {
patterns.add<RegAllocOpConversion>(typeConverter, benefit);
patterns.add<RegDeallocOpConversion>(typeConverter, benefit);
return;
}

View File

@@ -0,0 +1,14 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REGREALLOC_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REGREALLOC_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateRegReallocOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
const ModuleAllocation &allocation, PatternBenefit benefit);
#endif

View File

@@ -316,6 +316,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
void populateScanOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {

View File

@@ -8,6 +8,7 @@ using namespace mlir::triton;
void populateScanOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -0,0 +1,104 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "TensorPtrOpsToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
struct MakeTensorPtrOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeTensorPtrOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::MakeTensorPtrOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// struct { offset0, offset1, shape0, shape1, stride0,
// stride1, base_ptr};
auto offsets = adaptor.getOffsets();
auto shapes = adaptor.getShape();
auto strides = adaptor.getStrides();
auto base = adaptor.getBase();
auto result = op.getResult();
SmallVector<Value> elems;
for (auto offset : offsets)
elems.push_back(offset);
for (auto shape : shapes)
elems.push_back(shape);
for (auto stride : strides)
elems.push_back(stride);
elems.push_back(base);
auto newValue = getTypeConverter()->packLLElements(
op.getLoc(), elems, rewriter, result.getType());
rewriter.replaceOp(op, newValue);
return success();
}
};
struct AdvanceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AdvanceOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AdvanceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// struct { offset0, offset1, shape0, shape1, stride0,
// stride1, base_ptr};
auto loc = op.getLoc();
auto ptrType = op.getPtr().getType();
auto tensorPtr = adaptor.getPtr();
auto offsets = adaptor.getOffsets();
auto elems =
getTypeConverter()->unpackLLElements(loc, tensorPtr, rewriter, ptrType);
SmallVector<Value, 2> newOffsets;
for (auto [offset, oldOffset] : llvm::zip_first(offsets, elems)) {
newOffsets.push_back((add(offset, oldOffset)));
}
for (size_t i = 0; i < newOffsets.size(); ++i) {
elems[i] = newOffsets[i];
}
auto newValue = getTypeConverter()->packLLElements(op.getLoc(), elems,
rewriter, ptrType);
rewriter.replaceOp(op, newValue);
return success();
}
};
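For reference, a conceptual sketch of the packed block-pointer value these two conversions manipulate, shown for a 2-D tensor pointer. The struct and field widths are illustrative (the real value is an LLVM struct built by packLLElements), but the update rule of tt.advance matches the code above: only the leading offset fields change.

#include <array>
#include <cstdint>

struct TensorPtr2D {
  int32_t offset[2]; // updated by tt.advance
  int64_t shape[2];
  int64_t stride[2];
  void *base;
};

// tt.advance only rewrites the offset fields; shape, strides and the base
// pointer are carried through unchanged.
TensorPtr2D advance(TensorPtr2D p, std::array<int32_t, 2> delta) {
  p.offset[0] += delta[0];
  p.offset[1] += delta[1];
  return p;
}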
void populateTensorPtrOpsToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit) {
patterns.add<MakeTensorPtrOpConversion>(typeConverter, benefit);
patterns.add<AdvanceOpConversion>(typeConverter, benefit);
return;
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TENSOR_PTR_OPS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TENSOR_PTR_OPS_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateTensorPtrOpsToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation, PatternBenefit benefit);
#endif

View File

@@ -389,18 +389,23 @@ struct GetProgramIdOpConversion
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantics of GetProgramIdOp. If numCTAs = 1, then
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
// "%clusterid".
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
Location loc = op->getLoc();
assert(op.getAxisAsInt() < 3);
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
Value blockId =
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
Value programId = getSRegValue(rewriter, loc, sreg);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, programId);
return success();
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
};
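A small illustration of how the special-register name is assembled from numCTAs and the requested axis; the helper below is hypothetical and simply mirrors the string construction in the conversion above.

#include <cassert>
#include <string>

std::string programIdSreg(int numCTAs, int axis) {
  std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
  sreg.push_back(static_cast<char>('x' + axis)); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
  return sreg;
}

int main() {
  assert(programIdSreg(1, 0) == "%ctaid.x");     // single CTA per cluster
  assert(programIdSreg(4, 2) == "%clusterid.z"); // multi-CTA cluster
  return 0;
}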
struct GetNumProgramsOpConversion
@@ -411,19 +416,54 @@ struct GetNumProgramsOpConversion
LogicalResult
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantics of GetNumProgramsOp. If numCTAs = 1, then
// GetNumProgramsOp is converted to "%nctaid", otherwise it is converted to
// "%nclusterid".
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for GetNumProgramsOp");
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
Location loc = op->getLoc();
assert(op.getAxis() < 3);
std::string sreg = numCTAs == 1 ? "%nctaid." : "%nclusterid.";
sreg.append(1, 'x' + op.getAxis()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
Value blockId =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
Value numPrograms = getSRegValue(rewriter, loc, sreg);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, numPrograms);
return success();
}
};
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
// TODO[goostavz]: GetThreadIdOp/GetClusterCTAIdOp are a temporary solution
// until the async dialect is done. These concepts should appear at the ttgpu
// level, and they are planned to be deprecated along with ttgpu.mbarrier_xxx
// ops.
struct GetThreadIdOpConversion : public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetThreadIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetThreadIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::GetThreadIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, getThreadId(rewriter, op->getLoc()));
return success();
}
};
struct GetClusterCTAIdOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetClusterCTAIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetClusterCTAIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::GetClusterCTAIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, getClusterCTAId(rewriter, op->getLoc()));
return success();
}
};
struct AddPtrOpConversion
@@ -479,7 +519,8 @@ struct AllocTensorOpConversion
getTypeConverter()->convertType(resultTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
auto sharedLayout = resultTy.getEncoding().cast<SharedEncodingAttr>();
auto order = sharedLayout.getOrder();
// Workaround for 3D tensors
// TODO: we need to modify the pipeline pass to give a proper shared
// encoding to 3D tensors
@@ -489,8 +530,9 @@ struct AllocTensorOpConversion
else
newOrder = SmallVector<unsigned>(order.begin(), order.end());
auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder,
loc, rewriter);
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
auto smemObj =
SharedMemoryObject(smemBase, shapePerCTA, newOrder, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
@@ -593,6 +635,49 @@ struct AsyncCommitGroupOpConversion
}
};
struct AsyncBulkWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncBulkWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncBulkWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncBulkWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncBulkWaitOp = *ptxBuilder.create<>("cp.async.bulk.wait_group");
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncBulkWaitOp(ptxBuilder.newConstantOperand(num));
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = void_ty(ctx);
ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
struct AsyncBulkCommitGroupOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncBulkCommitGroupOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncBulkCommitGroupOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncBulkCommitGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
ptxBuilder.create<>("cp.async.bulk.commit_group")->operator()();
ptxBuilder.launch(rewriter, op.getLoc(), void_ty(op.getContext()));
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
namespace mlir {
namespace LLVM {
@@ -618,6 +703,7 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &moduleAllocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
@@ -626,12 +712,15 @@ void populateTritonGPUToLLVMPatterns(
benefit);
patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<AsyncBulkCommitGroupOpConversion>(typeConverter, benefit);
patterns.add<AsyncBulkWaitOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);
patterns.add<GetClusterCTAIdOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintOpConversion>(typeConverter, benefit);

View File

@@ -8,6 +8,7 @@ using namespace mlir::triton;
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -11,16 +11,28 @@
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Target/PTX/TmaMetadata.h"
#include <set>
#define DEBUG_TYPE "ttgpu_to_llvm"
constexpr ::llvm::StringLiteral kAttrNumTMALoadDescsName =
"triton_gpu.num-tma-load";
constexpr ::llvm::StringLiteral kAttrNumTMAStoreDescsName =
"triton_gpu.num-tma-store";
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::CTALayoutAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
using ::mlir::triton::gpu::TMAMetadataTy;
typedef DenseMap<Operation *, triton::MakeTensorPtrOp> TensorPtrMapT;
namespace mlir {
namespace LLVM {
@@ -141,36 +153,39 @@ protected:
}
};
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
struct IndexCacheKeyT {
Attribute layout;
RankedTensorType type;
bool withCTAOffset;
};
struct CacheKeyDenseMapInfo {
static IndexCacheKeyT getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
RankedTensorType{});
return {mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
RankedTensorType{}, true};
}
static IndexCacheKeyT getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
auto tombstone = llvm::DenseMapInfo<RankedTensorType>::getTombstoneKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
tombstone);
return {mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
tombstone, true};
}
static unsigned getHashValue(IndexCacheKeyT key) {
auto shape = key.second.getShape();
return llvm::hash_combine(mlir::hash_value(key.first),
mlir::hash_value(key.second));
return llvm::hash_combine(mlir::hash_value(key.layout),
mlir::hash_value(key.type),
llvm::hash_value(key.withCTAOffset));
}
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
return LHS == RHS;
return LHS.layout == RHS.layout && LHS.type == RHS.type &&
LHS.withCTAOffset == RHS.withCTAOffset;
}
};
class ConvertTritonGPUOpToLLVMPatternBase {
public:
// Two levels of value caching in emitting index calculations:
// Key: pair<layout, shape>
// Key: {layout, shape, withCTAOffset}
struct IndexCacheInfo {
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
*baseIndexCache;
@@ -198,6 +213,12 @@ public:
: converter(&typeConverter), allocation(&allocation),
indexCacheInfo(indexCacheInfo) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
TMAMetadataTy *tmaMetadata)
: converter(&typeConverter), allocation(&allocation),
tmaMetadata(tmaMetadata) {}
TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; }
static Value
@@ -223,6 +244,26 @@ public:
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
}
static Value getSRegValue(OpBuilder &b, Location loc,
const std::string &sRegStr) {
PTXBuilder builder;
auto &mov = builder.create("mov")->o("u32");
auto *destOpr = builder.newOperand("=r");
auto *sRegOpr = builder.newConstantOperand(sRegStr);
mov(destOpr, sRegOpr);
Value val = builder.launch(b, loc, b.getIntegerType(32), false);
auto cast = b.create<UnrealizedConversionCastOp>(
loc, TypeRange{b.getIntegerType(32)}, ValueRange{val});
return cast.getResult(0);
}
Value getClusterCTAId(ConversionPatternRewriter &rewriter,
Location loc) const {
return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(
loc, rewriter.getI32Type());
}
// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
@@ -259,7 +300,7 @@ public:
// for all indices (row, col) of `srcEncoding` such that idx % inVec = 0,
// the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] +
// colOff) where :
// compute phase = (row // perPhase) % maxPhase
// phase = (row // perPhase) % maxPhase
// rowOff = row
// colOff = colOffSwizzled + colOffOrdered
// colOffSwizzled = ((col // outVec) ^ phase) * outVec
@@ -280,60 +321,89 @@ public:
// then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
// This means that we can use some immediate offsets for shared memory
// operations.
auto dstPtrTy = ptr_ty(resElemTy, 3);
auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3);
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
auto srcEncoding = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy);
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
// swizzling params as described in TritonGPUAttrDefs.td
unsigned outVec = resSharedLayout.getVec();
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
// order
// Order
auto inOrder = triton::gpu::getOrder(srcEncoding);
auto outOrder = triton::gpu::getOrder(resSharedLayout);
// tensor indices held by the current thread, as LLVM values
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy);
// return values
// Tensor indices held by the current thread, as LLVM values
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
// Swizzling with leading offsets (e.g. Hopper GMMA)
unsigned swizzlingByteWidth = 0;
if (resSharedLayout.getHasLeadingOffset()) {
if (perPhase == 4 && maxPhase == 2)
swizzlingByteWidth = 32;
else if (perPhase == 2 && maxPhase == 4)
swizzlingByteWidth = 64;
else if (perPhase == 1 && maxPhase == 8)
swizzlingByteWidth = 128;
else
llvm::report_fatal_error("Unsupported shared layout.");
}
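// Worked example of the mapping above (illustration only, no new logic):
// perPhase == 1 && maxPhase == 8 selects the 128-byte swizzle, so for fp16
// elements numElemsPerSwizzlingRow = 128 * 8 / 16 = 64, i.e. one swizzling
// row covers 64 contiguous fp16 values.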
unsigned numElemsPerSwizzlingRow =
swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth();
Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
Value leadingDimOffsetVal = i32_val(leadingDimOffset);
// Return values
DenseMap<unsigned, Value> ret;
// cache for non-immediate offsets
DenseMap<unsigned, Value> cacheCol, cacheRow;
unsigned minVec = std::min(outVec, inVec);
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// extract multi dimensional index for current element
Value offset = i32_val(0);
// Extract multi dimensional index for current element
auto idx = srcIndices[elemIdx];
Value idxCol = idx[outOrder[0]]; // contiguous dimension
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
Value strideCol = srcStrides[outOrder[0]];
Value strideRow = srcStrides[outOrder[1]];
// extract dynamic/static offset for immediate offsetting
unsigned immedateOffCol = 0;
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp()))
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
unsigned cst =
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = cst % (outVec * maxPhase);
cacheCol.insert({key, idxCol});
idxCol = cacheCol[key];
immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase);
}
// extract dynamic/static offset for immediate offsetting
unsigned immedateOffRow = 0;
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxRow.getDefiningOp()))
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
unsigned cst =
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = cst % (perPhase * maxPhase);
cacheRow.insert({key, idxRow});
idxRow = cacheRow[key];
immedateOffRow = cst / (perPhase * maxPhase) * (perPhase * maxPhase);
}
// compute phase = (row // perPhase) % maxPhase
Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
// extract dynamic/static offset for immediate offsetting
unsigned immedateOffCol = 0;
unsigned immedateOffRow = 0;
if (leadingDimOffset) {
// Hopper
offset =
mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal);
// Shrink by swizzling blocks
idxCol = urem(idxCol, numElemsPerSwizzlingRowVal);
strideRow = numElemsPerSwizzlingRowVal;
} else {
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp()))
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
unsigned cst =
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = cst % (outVec * maxPhase);
cacheCol.insert({key, idxCol});
idxCol = cacheCol[key];
immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase);
}
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxRow.getDefiningOp()))
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
unsigned cst =
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = cst % (perPhase * maxPhase);
cacheRow.insert({key, idxRow});
idxRow = cacheRow[key];
immedateOffRow =
cst / (perPhase * maxPhase) * (perPhase * maxPhase);
}
}
// row offset is simply row index
Value rowOff = mul(idxRow, strideRow);
// because swizzling happens at a granularity of outVec, we need to
@@ -347,7 +417,7 @@ public:
colOffOrdered = mul(colOffOrdered, i32_val(minVec));
Value colOff = add(colOffSwizzled, colOffOrdered);
// compute non-immediate offset
Value offset = add(rowOff, mul(colOff, strideCol));
offset = add(offset, add(rowOff, mul(colOff, strideCol)));
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
// compute immediate offset
Value immedateOff =
@@ -477,7 +547,7 @@ public:
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
Value warpSize = i32_val(32);
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
@@ -487,7 +557,7 @@ public:
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTA[dim])
if (shape[dim] >= shapePerCTATile[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
@@ -535,13 +605,48 @@ public:
// Get offsets / indices for any layout
// -----------------------------------------------------------------------
SmallVector<Value> emitCTAOffsetForLayout(Location loc,
ConversionPatternRewriter &rewriter,
Attribute layout,
ArrayRef<int64_t> shape) const {
unsigned rank = shape.size();
SmallVector<unsigned> CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
SmallVector<unsigned> CTASplitNum = triton::gpu::getCTASplitNum(layout);
SmallVector<unsigned> CTAOrder = triton::gpu::getCTAOrder(layout);
SmallVector<int64_t> shapePerCTA =
triton::gpu::getShapePerCTA(CTASplitNum, shape);
// Delinearize clusterCTAId
Value clusterCTAId = getClusterCTAId(rewriter, loc);
SmallVector<Value> multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
// CTA Wrapping
for (unsigned i = 0; i < rank; ++i) {
// This wrapping rule must be consistent with getShapePerCTA
unsigned splitNum = std::min<unsigned>(shape[i], CTASplitNum[i]);
multiDimClusterCTAId[i] =
urem(multiDimClusterCTAId[i], i32_val(splitNum));
}
SmallVector<Value> CTAOffset(rank);
for (unsigned i = 0; i < rank; ++i)
CTAOffset[i] = mul(multiDimClusterCTAId[i], i32_val(shapePerCTA[i]));
return CTAOffset;
}
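A host-side sketch of the CTA offset computed above, under the simplifying assumption that dimension 0 is the fastest-varying dimension when delinearizing the cluster CTA id; the values are made up for illustration.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> shape = {128, 64};
  std::vector<unsigned> CTAsPerCGA = {2, 2};
  std::vector<unsigned> CTASplitNum = {2, 1};
  unsigned clusterCTAId = 3; // delinearizes to {1, 1}
  unsigned rank = shape.size();

  std::vector<unsigned> ctaId(rank);
  unsigned rest = clusterCTAId;
  for (unsigned i = 0; i < rank; ++i) {
    ctaId[i] = rest % CTAsPerCGA[i];
    rest /= CTAsPerCGA[i];
  }
  for (unsigned i = 0; i < rank; ++i) {
    long shapePerCTA = shape[i] / CTASplitNum[i];
    // Wrap the CTA id, consistent with getShapePerCTA.
    unsigned splitNum = std::min<unsigned>(shape[i], CTASplitNum[i]);
    long offset = (ctaId[i] % splitNum) * shapePerCTA;
    std::printf("dim %u: CTA offset %ld\n", i, offset); // dim 0: 64, dim 1: 0
  }
  return 0;
}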
SmallVector<Value> emitBaseIndexForLayout(Location loc,
ConversionPatternRewriter &rewriter,
Attribute layout,
RankedTensorType type) const {
IndexCacheKeyT key = std::make_pair(layout, type);
RankedTensorType type,
bool withCTAOffset) const {
auto shape = type.getShape();
IndexCacheKeyT key{layout, type, withCTAOffset};
auto cache = indexCacheInfo.baseIndexCache;
auto insertPt = indexCacheInfo.indexInsertPoint;
SmallVector<Value> baseIndex;
if (cache && cache->count(key) > 0) {
return cache->lookup(key);
} else {
@@ -550,23 +655,34 @@ public:
restoreInsertionPointIfSet(insertPt, rewriter);
SmallVector<Value> result;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
result =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, type);
result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
blockedLayout, type);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
result = emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter,
mmaLayout, type);
if (mmaLayout.isAmpere() || mmaLayout.isHopper())
result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter,
mmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentLayout);
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy);
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy,
withCTAOffset);
result.erase(result.begin() + sliceLayout.getDim());
// CTAOffset has been added in emitBaseIndexForLayout of parentLayout
return result;
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
if (withCTAOffset) {
auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape);
assert(CTAOffset.size() == result.size() && "Rank mismatch");
for (unsigned k = 0; k < result.size(); ++k)
result[k] = add(result[k], CTAOffset[k]);
}
if (cache) {
cache->insert(std::make_pair(key, result));
*insertPt = rewriter.saveInsertionPoint();
@@ -584,6 +700,8 @@ public:
return emitOffsetForMmaLayoutV1(mmaLayout, type);
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, type);
if (mmaLayout.isHopper())
return emitOffsetForMmaLayoutV3(mmaLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return emitOffsetForSliceLayout(sliceLayout, type);
@@ -593,11 +711,10 @@ public:
// -----------------------------------------------------------------------
// Emit indices
// -----------------------------------------------------------------------
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
Attribute layout,
RankedTensorType type) const {
IndexCacheKeyT key(layout, type);
SmallVector<SmallVector<Value>>
emitIndices(Location loc, ConversionPatternRewriter &b, Attribute layout,
RankedTensorType type, bool withCTAOffset = true) const {
IndexCacheKeyT key{layout, type, withCTAOffset};
auto cache = indexCacheInfo.indexCache;
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache && cache->count(key) > 0) {
@@ -608,11 +725,14 @@ public:
restoreInsertionPointIfSet(insertPt, b);
SmallVector<SmallVector<Value>> result;
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, blocked, type);
result = emitIndicesForDistributedLayout(loc, b, blocked, type,
withCTAOffset);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, type);
result =
emitIndicesForDistributedLayout(loc, b, mma, type, withCTAOffset);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, slice, type);
result =
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
@@ -642,19 +762,20 @@ private:
// Blocked layout indices
// -----------------------------------------------------------------------
// Get an index-base for each dimension for a \param blocked_layout.
SmallVector<Value> emitBaseIndexForBlockedLayout(
// Get an index-base for each dimension for a \param blockedLayout.
SmallVector<Value> emitBaseIndexWithinCTAForBlockedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blocked_layout, RankedTensorType type) const {
const BlockedEncodingAttr &blockedLayout, RankedTensorType type) const {
auto shape = type.getShape();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
auto sizePerThread = blocked_layout.getSizePerThread();
auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
auto order = blocked_layout.getOrder();
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
unsigned rank = shape.size();
// delinearize threadId to get the base index
@@ -666,10 +787,10 @@ private:
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// Wrap around multiDimWarpId/multiDimThreadId in case
// shape[k] > shapePerCTA[k]
// shapePerCTATile[k] > shapePerCTA[k]
auto maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
ceil<unsigned>(shapePerCTA[k], sizePerThread[k] * threadsPerWarp[k]);
auto maxThreads = ceil<unsigned>(shapePerCTA[k], sizePerThread[k]);
multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps));
multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
@@ -692,16 +813,17 @@ private:
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
auto shapePerCTATile = getShapePerCTATile(blockedLayout);
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
unsigned rank = shape.size();
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
SmallVector<unsigned> tilesPerDim(rank);
for (unsigned k = 0; k < rank; ++k)
tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
tilesPerDim[k] = ceil<unsigned>(shapePerCTA[k], shapePerCTATile[k]);
SmallVector<SmallVector<unsigned>> offset(rank);
for (unsigned k = 0; k < rank; ++k) {
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
// At least 1 CTA tile if shapePerCTA[k] is less than shapePerCTATile[k]
for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
@@ -741,12 +863,10 @@ private:
// Mma layout indices
// -----------------------------------------------------------------------
SmallVector<Value>
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
RankedTensorType type) const {
SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV1(
Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout, RankedTensorType type) const {
auto shape = type.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
static constexpr std::array<int, 3> fpw{{2, 2, 1}};
auto [isARow, isBRow, isAVec4, isBVec4, _] =
@@ -879,9 +999,6 @@ private:
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
unsigned lastAxis = order[order.size() - 1];
multiDimWarpId[lastAxis] =
urem(multiDimWarpId[lastAxis], i32_val(warpsPerCTA[lastAxis]));
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
Value offWarp0 = mul(multiDimWarpId[0], i32_val(16));
@@ -897,10 +1014,13 @@ private:
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
RankedTensorType type) const {
auto shape = type.getShape();
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
for (unsigned i = 0; i < shapePerCTA[0];
i += getShapePerCTATile(mmaLayout)[0]) {
for (unsigned j = 0; j < shapePerCTA[1];
j += getShapePerCTATile(mmaLayout)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
ret.push_back({i + 8, j});
@@ -910,17 +1030,88 @@ private:
return ret;
}
SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV2V3(
Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout, RankedTensorType type) const {
auto shape = type.getShape();
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
auto order = triton::gpu::getOrder(mmaLayout);
ArrayRef<unsigned int> instrShape = mmaLayout.getInstrShape();
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
uint32_t repM = (_warpsPerCTA[0] * instrShape[0]) / shapePerCTA[0];
uint32_t repN = (_warpsPerCTA[1] * instrShape[1]) / shapePerCTA[1];
uint32_t warpsM;
if (repM > 1)
warpsM = _warpsPerCTA[0] / repM;
else
warpsM = shape[0] / instrShape[0];
uint32_t warpsN;
if (repN > 1)
warpsN = _warpsPerCTA[1] / repN;
else
warpsN = shape[1] / instrShape[1];
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, _warpsPerCTA, order);
Value warpId0 = urem(multiDimWarpId[0], i32_val(warpsM));
Value warpId1 = urem(multiDimWarpId[1], i32_val(warpsN));
Value offWarp0 = mul(warpId0, i32_val(instrShape[0]));
Value offWarp1 = mul(warpId1, i32_val(instrShape[1]));
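// Worked example (illustrative, assuming Ampere mmav2 with instrShape = {16, 8}
// and a single warp): lane L owns row L / 4 and the column pair starting at
// 2 * (L % 4), so lane 5 gets base index {1, 2} and, combined with
// emitOffsetForMmaLayoutV2, covers elements (1, 2), (1, 3), (9, 2) and (9, 3)
// of the first tile.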
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
multiDimBase[1] = add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarp1);
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV3(const MmaEncodingAttr &mmaLayout,
RankedTensorType type) const {
auto shape = type.getShape();
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
SmallVector<SmallVector<unsigned>> ret;
ArrayRef<unsigned int> instrShape = mmaLayout.getInstrShape();
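// Illustrative enumeration (assuming a single CTA tile and instrShape[1] == 32):
// the loops below yield (0,0), (0,1), (8,0), (8,1), then the same quad shifted
// by 8 along N, i.e. (0,8), (0,9), (8,8), (8,9), and so on up to (8,25),
// i.e. 16 per-thread offsets in total.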
for (unsigned i = 0; i < shapePerCTA[0];
i += getShapePerCTATile(mmaLayout)[0]) {
for (unsigned j = 0; j < shapePerCTA[1];
j += getShapePerCTATile(mmaLayout)[1]) {
for (unsigned k = 0; k < instrShape[1]; k += 8) {
ret.push_back({i, j + k});
ret.push_back({i, j + k + 1});
ret.push_back({i + 8, j + k});
ret.push_back({i + 8, j + k + 1});
}
}
}
return ret;
}
// Emit the index calculation within each ConversionPattern and return a
// [elemsPerThread x rank] index matrix.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter, Attribute layout,
RankedTensorType type) const {
RankedTensorType type, bool withCTAOffset) const {
// step 1, delinearize threadId to get the base index
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
auto multiDimBase =
emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset);
// step 2, get offset of each element
auto offset = emitOffsetForLayout(layout, type);
// step 3, add offset to base, and reorder the sequence of indices to
// guarantee that elems in the same sizePerThread are adjacent in order
// step 3, add offset to base, and reorder the sequence
// of indices to guarantee that elems in the same
// sizePerThread are adjacent in order
auto shape = type.getShape();
unsigned rank = shape.size();
unsigned elemsPerThread = offset.size();
@@ -961,6 +1152,7 @@ protected:
TritonGPUToLLVMTypeConverter *converter;
ModuleAllocation *allocation;
IndexCacheInfo indexCacheInfo;
mlir::triton::gpu::TMAMetadataTy *tmaMetadata;
};
template <typename SourceOp>
@@ -975,18 +1167,18 @@ public:
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter,
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, indexCacheInfo) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
@@ -994,6 +1186,13 @@ public:
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
indexCacheInfo) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
tmaMetadata) {}
protected:
TritonGPUToLLVMTypeConverter *getTypeConverter() const {
LLVMTypeConverter *ret =

View File

@@ -14,20 +14,27 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetPlatform.hpp"
#include "BarrierOpToLLVM.h"
#include "ClusterOpsToLLVM.h"
#include "ConvertLayoutOpToLLVM.h"
#include "DotOpToLLVM.h"
#include "ElementwiseOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
#include "ReduceOpToLLVM.h"
#include "RegReallocOpToLLVM.h"
#include "ScanOpToLLVM.h"
#include "TensorPtrOpsToLLVM.h"
#include "TritonGPUToLLVM.h"
#include "TritonGPUToLLVMBase.h"
#include "TypeConverter.h"
#include "ViewOpToLLVM.h"
@@ -35,6 +42,7 @@
using namespace mlir;
using namespace mlir::triton;
namespace ttng = mlir::triton::nvidia_gpu;
#define GEN_PASS_CLASSES
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
@@ -56,6 +64,30 @@ public:
}
};
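// Fold a mask produced by triton::SplatOp directly into InsertSliceAsyncV2Op
// so that the TMA lowering consumes the scalar predicate instead of a full
// mask tensor. Schematic effect (illustrative, not exact IR syntax):
//   %m = tt.splat %p : i1 -> tensor<128x64xi1>
//   ... = insert_slice_async_v2 %src, %dst, ..., %m, ...
// becomes
//   ... = insert_slice_async_v2 %src, %dst, ..., %p, ...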
class FoldSplatMaskInInsertAsync : public mlir::RewritePattern {
public:
FoldSplatMaskInInsertAsync(mlir::MLIRContext *context)
: mlir::RewritePattern(
triton::nvidia_gpu::InsertSliceAsyncV2Op::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto insertOp = cast<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op);
if (!insertOp.getMask())
return failure();
auto splatOp = insertOp.getMask().getDefiningOp<triton::SplatOp>();
if (!splatOp)
return failure();
rewriter.updateRootInPlace(insertOp, [&]() {
insertOp.getMaskMutable().assign(splatOp->getOperand(0));
});
return mlir::success();
}
};
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
@@ -146,6 +178,18 @@ struct FuncOpConversion : public FuncOpConversionBase {
if (!allocation.isRoot(funcOp))
amendedFuncOp = amendFuncOp(funcOp, rewriter);
// Collect TMA information.
unsigned numTMALoad = 0;
funcOp.walk(
[&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
numTMALoad++;
});
unsigned numTMAStore = 0;
funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
numTMAStore++;
});
unsigned numTMA = numTMALoad + numTMAStore;
auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
if (!newFuncOp) {
return failure();
@@ -171,6 +215,30 @@ struct FuncOpConversion : public FuncOpConversionBase {
// The call graph is updated by mapping the old function to the new one.
allocation.mapFuncOp(funcOp, newFuncOp);
// Append arguments through which the runtime passes the TMA descriptors that reside in global memory
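// Illustrative effect (hypothetical kernel signature, not taken from this
// diff): with numTMALoad = 1 and numTMAStore = 1, a function type such as
//   (!llvm.ptr<f16, 1>, i32) -> ()
// becomes
//   (!llvm.ptr<f16, 1>, i32, !llvm.ptr<i8, 1>, !llvm.ptr<i8, 1>) -> ()
// and the descriptor counts are recorded below via kAttrNumTMALoadDescsName
// and kAttrNumTMAStoreDescsName.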
auto i8PtrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 1);
auto numArgs = newFuncOp.getBody().front().getNumArguments();
auto funcTy = newFuncOp.getFunctionType().cast<LLVM::LLVMFunctionType>();
SmallVector<Type> newInputsTy(funcTy.getParams().begin(),
funcTy.getParams().end());
for (unsigned i = 0; i < numTMA; ++i) {
newFuncOp.getBody().front().addArgument(i8PtrTy, funcOp.getLoc());
newInputsTy.push_back(i8PtrTy);
}
newFuncOp.setType(
LLVM::LLVMFunctionType::get(funcTy.getReturnType(), newInputsTy));
// required by AxisInfoAnalysis
for (unsigned i = 0; i < numTMA; ++i) {
newFuncOp.setArgAttr(numArgs + i, "tt.divisibility",
rewriter.getIntegerAttr(i32_ty, 1));
}
newFuncOp->setAttr(kAttrNumTMALoadDescsName,
rewriter.getIntegerAttr(i32_ty, numTMALoad));
newFuncOp->setAttr(kAttrNumTMAStoreDescsName,
rewriter.getIntegerAttr(i32_ty, numTMAStore));
rewriter.eraseOp(funcOp);
return success();
}
@@ -247,7 +315,6 @@ private:
this->getTypeConverter()->packFunctionResults(resultTypes)))
return nullptr;
}
auto newCallOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promotedOperands, callOp->getAttrs());
@@ -288,8 +355,10 @@ public:
} else {
addLegalDialect<NVVM::NVVMDialect>();
}
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
@@ -299,8 +368,11 @@ class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
public:
explicit ConvertTritonGPUToLLVM(int computeCapability, bool isROCM)
: computeCapability(computeCapability), isROCM(isROCM) {}
explicit ConvertTritonGPUToLLVM(int computeCapability,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
bool isROCM)
: computeCapability(computeCapability), tmaMetadata(tmaMetadata),
isROCM(isROCM) {}
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -310,19 +382,54 @@ public:
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMConversionTarget target(*context, isROCM);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
// Preprocess
decomposeFp8e4b15Convert(mod);
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
decomposeMixedModeDotOp(mod);
// Allocate shared memory and set barrier
ModuleAllocation allocation(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();
/* Collect the tensorPtrMap (TMA load/store op -> defining MakeTensorPtrOp) before conversion */
TensorPtrMapT tensorPtrMap;
mod.walk([&tensorPtrMap](
mlir::triton::nvidia_gpu::InsertSliceAsyncV2Op insertOp) {
auto src = insertOp.getSrc();
auto ptrTy = src.getType().dyn_cast<triton::PointerType>();
if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
auto makeTensorPtrOp = getMakeTensorPtrOp(insertOp.getSrc());
tensorPtrMap[insertOp.getOperation()] = makeTensorPtrOp;
}
});
mod.walk([&tensorPtrMap](mlir::triton::nvidia_gpu::StoreAsyncOp storeOp) {
auto dst = storeOp.getDst();
auto ptrTy = dst.getType().dyn_cast<triton::PointerType>();
if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
auto makeTensorPtrOp = getMakeTensorPtrOp(storeOp.getDst());
tensorPtrMap[storeOp.getOperation()] = makeTensorPtrOp;
}
});
// Hack: cleanup. Fold splat masks into InsertSliceAsyncV2 ops before lowering.
{
RewritePatternSet patterns(context);
patterns.add<FoldSplatMaskInInsertAsync>(context);
SmallVector<Operation *> insertSlices;
mod.walk([&insertSlices](triton::nvidia_gpu::InsertSliceAsyncV2Op op) {
insertSlices.push_back(op);
});
if (applyOpPatternsAndFold(insertSlices, std::move(patterns)).failed())
signalPassFailure();
}
// Lower functions
{
mlir::LowerToLLVMOptions option(context);
@@ -358,9 +465,14 @@ public:
}
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
// Rewrite ops
RewritePatternSet patterns(context);
// TritonGPU lowering patterns
// Emit logic to get threadId/blockIds/linearized clusterCTAId etc. and
// cache the values. The reason to do it here is that cluster_ctaid is
// currently implemented via inline asm and thus cannot be CSEed.
// clusterCTAId will be emitted only when numCTAs is larger than 1, and the
// other values will be DCEed if not used hereafter.
bool isWarpSpecialization =
ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod);
OpBuilder::InsertPoint indexInsertPoint;
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
&baseIndexCache, &indexCache, &indexInsertPoint};
@@ -368,21 +480,56 @@ public:
if (axisInfoAnalysis.getNumFunctions() > 1) {
indexCacheInfo = {nullptr, nullptr, nullptr};
}
populateTritonGPUToLLVMPatterns(typeConverter, patterns, allocation,
indexCacheInfo, /*benefit=*/1);
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, allocation,
indexCacheInfo, /*benefit=*/1);
populateDotOpToLLVMPatterns(typeConverter, patterns, allocation,
/*benefit=*/1);
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
allocation, indexCacheInfo,
/*benefit=*/1);
populateReduceOpToLLVMPatterns(typeConverter, patterns, allocation,
indexCacheInfo, /*benefit=*/1);
populateScanOpToLLVMPatterns(typeConverter, patterns, allocation,
indexCacheInfo, /*benefit=*/1);
populateViewOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
// tmaMetadata is absent in triton-opt unit tests; in that case, create a
// local one and dump it after this pass is done.
mlir::triton::gpu::TMAMetadataTy tmaMetaDataDebug;
if (tmaMetadata == nullptr)
tmaMetadata = &tmaMetaDataDebug;
RewritePatternSet patterns(context);
auto populatePatterns1 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
allocation, indexCacheInfo,
/*benefit*/ 10);
};
auto populatePatterns2 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
allocation, /*benefit*/ 10);
};
auto populatePatterns3 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
allocation, indexCacheInfo, tmaMetadata, &tensorPtrMap,
/*benefit*/ 10);
};
auto populatePatterns4 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
allocation, indexCacheInfo, computeCapability,
/*benefit*/ 10);
};
populatePatterns1(populateTritonGPUToLLVMPatterns);
populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
populatePatterns2(populateDotOpToLLVMPatterns);
populatePatterns2(populateElementwiseOpToLLVMPatterns);
populatePatterns3(populateLoadStoreOpToLLVMPatterns);
populatePatterns4(populateReduceOpToLLVMPatterns);
populatePatterns1(populateScanOpToLLVMPatterns);
populatePatterns2(populateViewOpToLLVMPatterns);
populatePatterns2(populateBarrierOpToLLVMPatterns);
populatePatterns2(populateTensorPtrOpsToLLVMPatterns);
populatePatterns2(populateClusterOpsToLLVMPatterns);
populatePatterns2(populateRegReallocOpToLLVMPatterns);
// TODO(thomas): this should probably be done in a separate step so as not to
// interfere with our own lowering of arith ops. Add arith/math's patterns
// to help convert scalar expressions to LLVM.
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
// Native lowering patterns
if (isROCM) {
@@ -396,10 +543,18 @@ public:
patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
// Fold ClusterCTAId to zero when there is only 1 CTA.
if (numCTAs == 1) {
mod.walk([](triton::nvgpu::ClusterCTAIdOp id) {
OpBuilder b(id);
Value zero = LLVM::createConstantI32(id->getLoc(), b, 0);
id.replaceAllUsesWith(zero);
});
}
}
private:
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
baseIndexCache;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
@@ -408,6 +563,7 @@ private:
int computeCapability{};
bool isROCM{};
mlir::triton::gpu::TMAMetadataTy *tmaMetadata;
void initSharedMemory(ModuleAllocation &allocation,
TritonGPUToLLVMTypeConverter &typeConverter) {
@@ -470,8 +626,8 @@ private:
});
}
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps,
int threadsPerWarp) const {
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
int numCTAs) const {
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
// unless certain conditions are met
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
@@ -487,7 +643,7 @@ private:
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
getOrder(srcMma), numWarps, threadsPerWarp));
getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
@@ -514,7 +670,8 @@ private:
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(
mod.getContext(), dstDotOp, srcType.getShape(),
getOrder(srcBlocked), srcType.getElementType()));
srcBlocked.getOrder(), srcBlocked.getCTALayout(),
srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
@@ -632,6 +789,52 @@ private:
}
});
}
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
Type promotedType) {
Type tensorPromotedType =
operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
promotedType);
return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
}
// Promote the operands of a dot op if the existing element-type combination
// is not natively supported.
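// Illustrative example (assumed types, not taken from this diff): with an
// Ampere MMA layout, a dot whose A/B operands are float8E5M2 is rewritten so
// that both operands are first converted to f16 via triton::FpToFpOp, since
// only Hopper handles that fp8 combination natively; in the FMA path the
// operands are instead promoted to the accumulator's element type.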
void decomposeMixedModeDotOp(ModuleOp mod) const {
mod.walk([](triton::DotOp dotOp) -> void {
Value D = dotOp.getResult();
OpBuilder builder(dotOp);
Type AElType =
dotOp.getA().getType().cast<RankedTensorType>().getElementType();
Type promoteType;
MmaEncodingAttr mmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
if (mmaLayout) {
bool isNativeHopperFP8 =
AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() ||
AElType.isFloat8E4M3FNUZ();
if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper()))
return;
promoteType = builder.getF16Type();
} else {
// FMA case.
Type AElType =
dotOp.getA().getType().cast<RankedTensorType>().getElementType();
Type DElType = D.getType().cast<RankedTensorType>().getElementType();
if (AElType == DElType)
return;
promoteType = DElType;
}
Location loc = dotOp.getLoc();
Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType);
Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType);
dotOp.setOperand(0, promotedA);
dotOp.setOperand(1, promotedB);
});
}
};
} // anonymous namespace
@@ -640,8 +843,11 @@ namespace mlir {
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability, bool isROCM) {
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability, isROCM);
createConvertTritonGPUToLLVMPass(int computeCapability,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
bool isROCM) {
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability,
tmaMetadata, isROCM);
}
} // namespace triton

View File

@@ -41,7 +41,27 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
triton::PointerType type) {
// Recursively translate pointee type
auto ctx = type.getContext();
auto pointeeType = type.getPointeeType();
if (pointeeType.isa<RankedTensorType>()) {
auto rankedTensorType = pointeeType.cast<RankedTensorType>();
// struct { offset0, offset1, shape0, shape1, stride0,
// stride1, base_ptr};
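// e.g. a 2-D block pointer to f16 in global memory lowers to (illustrative):
//   !llvm.struct<(i32, i32, i64, i64, i64, i64, !llvm.ptr<f16, 1>)>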
auto eleType = rankedTensorType.getElementType();
auto shape = rankedTensorType.getShape();
SmallVector<Type, 4> types;
// offsets
for (size_t i = 0; i < shape.size(); ++i)
types.push_back(IntegerType::get(ctx, 32));
// shapes, strides
for (size_t i = 0; i < 2 * shape.size(); ++i)
types.push_back(IntegerType::get(ctx, 64));
types.push_back(
LLVM::LLVMPointerType::get(eleType, type.getAddressSpace()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
}
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}

View File

@@ -6,19 +6,19 @@ namespace mlir {
namespace LLVM {
using namespace mlir::triton;
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v));
}
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
Value createConstantF32(Location loc, OpBuilder &rewriter, float v) {
auto type = type::f32Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF32FloatAttr(v));
}
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
Value createConstantF64(Location loc, OpBuilder &rewriter, float v) {
auto type = type::f64Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF64FloatAttr(v));
@@ -40,6 +40,96 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
builder.getIntegerAttr(ty, value));
}
// A wrapper of LoadDSmemOp when vec = 1
// (1) Get bitwidth from elemTy
// (2) Create LoadDSmemOp
// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId) {
assert(addr.getType().isa<LLVMPointerType>() &&
"addr must be a pointer type");
auto ptrTy = addr.getType().cast<LLVMPointerType>();
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
auto elemTy = ptrTy.getElementType();
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
Value ret =
rewriter.create<triton::nvgpu::LoadDSmemOp>(loc, addr, ctaId, bitwidth);
return bitcast(ret, elemTy);
}
// A wrapper of LoadDSmemOp when vec > 1
// (1) Get bitwidth from elemTy
// (2) Create LoadDSmemOp and extract results from retStruct
// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
Value addr, Value ctaId, unsigned vec) {
assert(addr.getType().isa<LLVMPointerType>() &&
"addr must be a pointer type");
auto ptrTy = addr.getType().cast<LLVMPointerType>();
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
auto elemTy = ptrTy.getElementType();
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
Value retStruct = rewriter.create<triton::nvgpu::LoadDSmemOp>(
loc, addr, ctaId, bitwidth, vec);
SmallVector<Value> retVals;
for (unsigned i = 0; i < vec; ++i) {
auto dataTy = rewriter.getIntegerType(bitwidth);
Value data = extract_val(dataTy, retStruct, i);
retVals.push_back(bitcast(data, elemTy));
}
return retVals;
}
// A wrapper of StoreDSmemOp when vec = 1
// (1) Get bitwidth from elemTy
// (2) Bitcast value from elemTy to dataTy (u16/u32/u64)
// (3) Create StoreDSmemOp
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, Value value, Value pred) {
assert(addr.getType().isa<LLVMPointerType>() &&
"addr must be a pointer type");
auto ptrTy = addr.getType().cast<LLVMPointerType>();
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
auto elemTy = ptrTy.getElementType();
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
auto dataTy = rewriter.getIntegerType(bitwidth);
Value data = bitcast(value, dataTy);
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
}
// A wrapper of StoreDSmemOp when vec = 1 and pred = 1
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, Value value) {
Value pred = int_val(/*width=*/1, 1);
createStoreDSmem(loc, rewriter, addr, ctaId, value, pred);
}
// A wrapper of StoreDSmemOp when vec > 1
// (1) Get bitwidth from elemTy
// (2) Bitcast values from elemTy to dataTy (u16/u32/u64)
// (3) Create StoreDSmemOp
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, ArrayRef<Value> values, Value pred) {
assert(addr.getType().isa<LLVMPointerType>() &&
"addr must be a pointer type");
auto ptrTy = addr.getType().cast<LLVMPointerType>();
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
auto elemTy = ptrTy.getElementType();
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
auto dataTy = rewriter.getIntegerType(bitwidth);
SmallVector<Value> data;
for (unsigned i = 0; i < values.size(); ++i)
data.push_back(bitcast(values[i], dataTy));
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
}
// A wrapper of StoreDSmemOp when vec > 1 and pred = 1
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, ArrayRef<Value> values) {
Value pred = int_val(/*width=*/1, 1);
createStoreDSmem(loc, rewriter, addr, ctaId, values, pred);
}
SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {

View File

@@ -14,6 +14,7 @@
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define sext(...) rewriter.create<LLVM::SExtOp>(loc, __VA_ARGS__)
#define fpext(...) rewriter.create<LLVM::FPExtOp>(loc, __VA_ARGS__)
#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -43,6 +44,8 @@
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__)
#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__)
#define fcmp_ogt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::ogt, lhs, rhs)
@@ -75,6 +78,15 @@
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define barSync(rewriter, op, bar, numThreads) \
do { \
::mlir::triton::PTXBuilder ptxBuilder; \
auto &barSyncOp = *ptxBuilder.create<>("bar.sync"); \
barSyncOp(ptxBuilder.newConstantOperand(bar), \
ptxBuilder.newConstantOperand(numThreads)); \
auto voidTy = void_ty(op->getContext()); \
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
} while (0)
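// Minimal usage sketch for barSync (values are illustrative; rewriter and op
// are assumed to be in scope, as with the other macros in this file):
//   barSync(rewriter, op, 0, 128); // emits an inline "bar.sync 0, 128"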
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
@@ -84,6 +96,8 @@
#define i64_ty rewriter.getIntegerType(64)
#define i32_ty rewriter.getIntegerType(32)
#define i16_ty rewriter.getIntegerType(16)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
@@ -174,13 +188,13 @@ T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
namespace LLVM {
using namespace mlir::triton;
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v);
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);
/// Create a 32-bit float constant.
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v);
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);
/// Create a 64-bit float constant.
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v);
Value createConstantF64(Location loc, OpBuilder &rewriter, float v);
/// Create an index type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
@@ -190,6 +204,28 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value);
/// Usage of macro load_dsmem
/// (1) load_dsmem(addr, ctaId)
/// (2) load_dsmem(addr, ctaId, vec)
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId);
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
Value addr, Value ctaId, unsigned vec);
/// Usage of macro store_dsmem
/// (1) store_dsmem(addr, ctaId, value, pred)
/// (2) store_dsmem(addr, ctaId, value)
/// (3) store_dsmem(addr, ctaId, values, pred)
/// (4) store_dsmem(addr, ctaId, values)
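/// Minimal usage sketch (illustrative; loc and rewriter are assumed to be in
/// scope, and addr, ctaId, val, vals and pred are Values already created by
/// the surrounding lowering):
///   Value x = load_dsmem(addr, ctaId);                  // scalar load
///   SmallVector<Value> xs = load_dsmem(addr, ctaId, 4); // vec = 4
///   store_dsmem(addr, ctaId, val);                      // unpredicated store
///   store_dsmem(addr, ctaId, vals, pred);               // predicated vector store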
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, Value value, Value pred);
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, Value value);
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, ArrayRef<Value> values, Value pred);
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, ArrayRef<Value> values);
/// Helper function to get strides from a given shape and its order
SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,

View File

@@ -104,6 +104,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
using OpAdaptor = typename CatOp::Adaptor;
explicit CatOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
@@ -138,6 +139,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
using OpAdaptor = typename ViewOp::Adaptor;
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}
@@ -159,6 +161,7 @@ struct ExpandDimsOpConversion
: public ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp> {
using OpAdaptor = typename ExpandDimsOp::Adaptor;
explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp>(typeConverter, benefit) {}
@@ -221,7 +224,9 @@ struct TransOpConversion
};
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
PatternBenefit benefit) {
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);

View File

@@ -7,7 +7,9 @@ using namespace mlir;
using namespace mlir::triton;
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
PatternBenefit benefit);
#endif

View File

@@ -10,6 +10,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Target/PTX/TmaMetadata.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
@@ -240,10 +241,19 @@ struct TritonExpandDimsPattern
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
SmallVector<unsigned, 4> retOrder(retShape.size());
std::iota(retOrder.begin(), retOrder.end(), 0);
auto argCTALayout = argEncoding.getCTALayout();
auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis());
auto retCTASplitNum =
insertOne(argCTALayout.getCTASplitNum(), op.getAxis());
auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis());
auto retCTALayout = triton::gpu::CTALayoutAttr::get(
getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder);
triton::gpu::BlockedEncodingAttr retEncoding =
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
retThreadsPerWarp, retWarpsPerCTA,
retOrder);
retOrder, retCTALayout);
// convert operand to slice of return type
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
getContext(), op.getAxis(), retEncoding);
@@ -257,6 +267,26 @@ struct TritonExpandDimsPattern
adaptor.getAttributes());
return success();
}
private:
template <typename T>
SmallVector<T> insertOne(ArrayRef<T> vec, unsigned axis) const {
SmallVector<T> res(vec.begin(), vec.end());
res.insert(res.begin() + axis, 1);
return res;
}
// Example: order = [0, 2, 1, 3], axis = 2
// resOrder = [2, 0, 3, 1, 4]
SmallVector<unsigned> insertOrder(ArrayRef<unsigned> order,
unsigned axis) const {
SmallVector<unsigned> resOrder(order.begin(), order.end());
for (unsigned i = 0; i < resOrder.size(); ++i)
if (resOrder[i] >= axis)
++resOrder[i];
resOrder.insert(resOrder.begin(), axis);
return resOrder;
}
};
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
@@ -270,6 +300,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
int numWarps = typeConverter->getNumWarps();
int threadsPerWarp = typeConverter->getThreadsPerWarp();
int numCTAs = typeConverter->getNumCTAs();
SmallVector<unsigned> retSizePerThread = {1, 1};
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 4)
@@ -279,7 +310,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
SmallVector<unsigned> retOrder = {1, 0};
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
getContext(), origShape, retSizePerThread, retOrder, numWarps,
threadsPerWarp);
threadsPerWarp, numCTAs);
RankedTensorType retType =
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
// a & b must be of smem layout
@@ -354,9 +385,9 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
newRetSizePerThread[retOrder[0]] *=
newRetTotalElemsPerThread / retTotalElemsPerThread;
triton::gpu::BlockedEncodingAttr newRetEncoding =
triton::gpu::BlockedEncodingAttr::get(getContext(), newRetSizePerThread,
retThreadsPerWarp, retWarpsPerCTA,
retOrder);
triton::gpu::BlockedEncodingAttr::get(
getContext(), newRetSizePerThread, retThreadsPerWarp,
retWarpsPerCTA, retOrder, retEncoding.getCTALayout());
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
newRetEncoding);
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
@@ -386,8 +417,12 @@ struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
if (auto srcBlockedEncoding =
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
srcEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
// TODO(Qingyi): need to check whether the CTALayout of srcEncoding should
// be used here. For tests where numCTAs = 1, this is not a problem since
// all CTALayouts are the same.
auto CTALayout = triton::gpu::getCTALayout(srcEncoding);
srcEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1,
order, CTALayout);
srcType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), srcEncoding);
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
@@ -658,10 +693,12 @@ public:
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
RewritePatternSet &patterns, unsigned numCTAs) {
MLIRContext *context = patterns.getContext();
patterns
.insert< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::AdvanceOp>,
TritonGenericPattern<triton::MakeTensorPtrOp>,
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::BitcastOp>,
TritonGenericPattern<triton::FpToFpOp>,
@@ -889,16 +926,20 @@ class ConvertTritonToTritonGPU
public:
ConvertTritonToTritonGPU() = default;
// constructor with some parameters set explicitly.
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp) {
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp, int numCTAs,
int computeCapability) {
this->numWarps = numWarps;
this->threadsPerWarp = threadsPerWarp;
this->numCTAs = numCTAs;
this->computeCapability = computeCapability;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
// type converter
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp);
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp,
numCTAs);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);
@@ -906,7 +947,7 @@ public:
populateStdPatternsAndLegality(typeConverter, patterns, target);
populateArithPatternsAndLegality(typeConverter, patterns, target);
populateMathPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);
populateTritonPatterns(typeConverter, patterns, numCTAs);
// TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
@@ -925,6 +966,13 @@ public:
AttrNumThreadsPerWarp,
IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue())));
mod->setAttr(AttrNumCTAsName,
IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue())));
mod->setAttr(AttrComputeCapabilityName,
IntegerAttr::get(
i32_ty, llvm::APInt(32, computeCapability.getValue())));
// update layouts
// broadcast src => multicast, dst => broadcasted
// if (failed(target.refineLayouts(mod, numWarps)))
@@ -936,8 +984,11 @@ public:
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps,
int threadsPerWarp) {
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps, threadsPerWarp);
int threadsPerWarp,
int numCTAs,
int computeCapability) {
return std::make_unique<::ConvertTritonToTritonGPU>(
numWarps, threadsPerWarp, numCTAs, computeCapability);
}
std::unique_ptr<OperationPass<ModuleOp>>

View File

@@ -1,2 +1,4 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)
add_subdirectory(TritonNvidiaGPU)
add_subdirectory(NVGPU)

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(ToLLVMIR)

View File

@@ -0,0 +1,9 @@
add_mlir_dialect_library(NVGPUIR
Dialect.cpp
DEPENDS
NVGPUTableGen
NVGPUAttrDefsIncGen
LINK_LIBS PUBLIC
)

Some files were not shown because too many files have changed in this diff.