mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Initial code merge of Hopper support (#2036)
The initial code merge of Nvidia Hopper features support. Please be aware that the code merge is not finished yet and the trouble-shooting is still ongoing. The new hardware features (GMMA, TMA, STMATRIX etc.) and automatic warp-specialization are experimental for now and turned off by default. It is recommended for a trial when version 3.0 is released. The work is contributed by: ben-zhang-609, bealwang, donproc, qliu93, jsh20, allatit23, LyricZhao, ivanyinwz, goostavz & yangjunpro from Nvidia, in cooperation with: ptillet, Jokeren, ThomasRaoux & zahimoud from OpenAI. Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
This commit is contained in:
30
.github/workflows/integration-tests.yml
vendored
30
.github/workflows/integration-tests.yml
vendored
@@ -50,6 +50,8 @@ jobs:
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
|
||||
run: |
|
||||
echo "BACKEND=CUDA" >> "${GITHUB_ENV}"
|
||||
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
|
||||
echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Clear cache
|
||||
run: |
|
||||
@@ -79,8 +81,32 @@ jobs:
|
||||
fi
|
||||
lit -v "${LIT_TEST_DIR}"
|
||||
|
||||
- name: Run python tests on CUDA
|
||||
if: ${{ env.BACKEND == 'CUDA'}}
|
||||
- name: Enable MMAV3 and TMA
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}}
|
||||
run: |
|
||||
echo "ENABLE_TMA=1" >> "${GITHUB_ENV}"
|
||||
echo "ENABLE_MMA_V3=1" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Run python tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1
|
||||
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}}
|
||||
run: |
|
||||
cd python/test/unit
|
||||
python3 -m pytest -n 8 --ignore=runtime
|
||||
# run runtime tests serially to avoid race condition with cache handling.
|
||||
python3 -m pytest runtime/
|
||||
|
||||
- name: Disable MMAV3 and TMA
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}}
|
||||
run: |
|
||||
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
|
||||
echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Clear cache
|
||||
run: |
|
||||
rm -rf ~/.triton
|
||||
|
||||
- name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0
|
||||
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}}
|
||||
run: |
|
||||
cd python/test/unit
|
||||
python3 -m pytest -n 8 --ignore=runtime
|
||||
|
||||
@@ -209,6 +209,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
TritonNvidiaGPUTransforms
|
||||
TritonLLVMIR
|
||||
TritonPTX
|
||||
TritonHSACO
|
||||
|
||||
@@ -9,6 +9,7 @@ target_link_libraries(triton-opt PRIVATE
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
TritonNvidiaGPUTransforms
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# tests
|
||||
@@ -29,6 +30,7 @@ target_link_libraries(triton-reduce PRIVATE
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
TritonNvidiaGPUTransforms
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# tests
|
||||
@@ -48,6 +50,7 @@ llvm_update_compile_flags(triton-translate)
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
TritonNvidiaGPUTransforms
|
||||
TritonLLVMIR
|
||||
TritonPTX
|
||||
TritonHSACO
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
#pragma once
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
|
||||
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
|
||||
@@ -23,6 +25,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) {
|
||||
mlir::registerAllPasses();
|
||||
mlir::registerTritonPasses();
|
||||
mlir::registerTritonGPUPasses();
|
||||
mlir::registerTritonNvidiaGPUPasses();
|
||||
mlir::test::registerTestAliasPass();
|
||||
mlir::test::registerTestAlignmentPass();
|
||||
mlir::test::registerTestAllocationPass();
|
||||
@@ -32,6 +35,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) {
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
|
||||
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
||||
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
|
||||
mlir::gpu::GPUDialect>();
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Target/HSACO/HSACOTranslation.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
@@ -38,6 +39,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
||||
mlir::DialectRegistry registry;
|
||||
registry
|
||||
.insert<TritonDialect, triton::gpu::TritonGPUDialect,
|
||||
triton::nvidia_gpu::TritonNvidiaGPUDialect,
|
||||
mlir::math::MathDialect, arith::ArithDialect, scf::SCFDialect>();
|
||||
|
||||
context.appendDialectRegistry(registry);
|
||||
@@ -121,8 +123,10 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
|
||||
SMArch.getValue(), false /*isRocm*/);
|
||||
mlir::triton::gpu::TMAMetadataTy tmaInfos;
|
||||
auto llvmir = translateTritonGPUToLLVMIR(
|
||||
&llvmContext, *module, SMArch.getValue(), tmaInfos, false /*isRocm*/);
|
||||
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
}
|
||||
|
||||
@@ -192,3 +192,4 @@ Iterators
|
||||
:nosignatures:
|
||||
|
||||
static_range
|
||||
multiple_of
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
|
||||
@@ -147,17 +148,17 @@ private:
|
||||
BufferKind kind;
|
||||
BufferId id;
|
||||
size_t size;
|
||||
size_t alignment;
|
||||
size_t offset;
|
||||
|
||||
bool operator==(const BufferT &other) const { return id == other.id; }
|
||||
bool operator<(const BufferT &other) const { return id < other.id; }
|
||||
|
||||
BufferT() : BufferT(BufferKind::Explicit) {}
|
||||
BufferT(BufferKind kind)
|
||||
: kind(kind), id(InvalidBufferId), size(0), offset(0) {}
|
||||
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t offset)
|
||||
: kind(kind), id(nextId++), size(size), offset(offset) {}
|
||||
BufferT() : BufferT(BufferKind::Explicit, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
|
||||
size_t offset = 0)
|
||||
: kind(kind), id(nextId++), size(size), alignment(alignment),
|
||||
offset(offset) {}
|
||||
};
|
||||
|
||||
/// Op -> Scratch Buffer
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace mlir {
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef SmallVector<int64_t, 4> DimVectorT;
|
||||
typedef SmallVector<int64_t> DimVectorT;
|
||||
|
||||
public:
|
||||
/// Default constructor
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
@@ -121,7 +122,11 @@ bool isSingleValue(Value value);
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
Type getElementType(Value value);
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
// TODO: Move utility functions that belong to ConvertLayoutOp to class
|
||||
// ConvertLayoutOpHelper in the future
|
||||
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
|
||||
@@ -324,6 +329,10 @@ protected:
|
||||
FuncDataMapT funcMap;
|
||||
SmallVector<FunctionOpInterface> roots;
|
||||
};
|
||||
// Create a basic DataFlowSolver with constant and dead code analysis included.
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
|
||||
|
||||
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
"mlir::tensor::TensorDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::ROCDL::ROCDLDialect",
|
||||
"mlir::NVVM::NVVMDialect"];
|
||||
|
||||
@@ -26,6 +27,9 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">,
|
||||
Option<"TmaMetadata", "tma-metadata",
|
||||
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
|
||||
"tma metadata to the runtime">,
|
||||
Option<"isROCM", "is-rocm",
|
||||
"bool", /*default*/"false",
|
||||
"compile for ROCM-compatible LLVM">,
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
@@ -12,9 +14,10 @@ template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
|
||||
bool isROCM = false);
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass(
|
||||
int computeCapability = 80,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr,
|
||||
bool isROCM = false);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
@@ -25,6 +25,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
||||
Option<"threadsPerWarp", "threads-per-warp",
|
||||
"int32_t", /*default*/"32",
|
||||
"number of threads per warp">,
|
||||
Option<"numCTAs", "num-ctas",
|
||||
"int32_t", /*default*/"1",
|
||||
"number of ctas in a cga">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,9 @@ template <typename T> class OperationPass;
|
||||
namespace triton {
|
||||
|
||||
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
|
||||
constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
|
||||
constexpr static char AttrComputeCapabilityName[] =
|
||||
"triton_gpu.compute-capability";
|
||||
|
||||
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
|
||||
|
||||
@@ -19,7 +22,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
|
||||
|
||||
// Create the pass with numWarps set explicitly.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32);
|
||||
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32,
|
||||
int numCTAs = 1, int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_subdirectory(Triton)
|
||||
add_subdirectory(TritonGPU)
|
||||
add_subdirectory(TritonNvidiaGPU)
|
||||
add_subdirectory(NVGPU)
|
||||
|
||||
2
include/triton/Dialect/NVGPU/CMakeLists.txt
Normal file
2
include/triton/Dialect/NVGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
#add_subdirectory(Transforms)
|
||||
14
include/triton/Dialect/NVGPU/IR/CMakeLists.txt
Normal file
14
include/triton/Dialect/NVGPU/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
set(LLVM_TARGET_DEFINITIONS NVGPUOps.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu)
|
||||
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(NVGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td)
|
||||
mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
add_public_tablegen_target(NVGPUAttrDefsIncGen)
|
||||
47
include/triton/Dialect/NVGPU/IR/Dialect.h
Normal file
47
include/triton/Dialect/NVGPU/IR/Dialect.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_NVGPU_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_NVGPU_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/NVGPU/IR/OpsEnums.h.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/NVGPU/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace nvgpu {} // namespace nvgpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
33
include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td
Normal file
33
include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef NVGPU_ATTRDEFS
|
||||
#define NVGPU_ATTRDEFS
|
||||
|
||||
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
|
||||
class NVGPU_Attr<string name, list<Trait> traits = [],
|
||||
string baseCppClass = "::mlir::Attribute">
|
||||
: AttrDef<NVGPU_Dialect, name, traits, baseCppClass> {
|
||||
}
|
||||
|
||||
#endif
|
||||
40
include/triton/Dialect/NVGPU/IR/NVGPUDialect.td
Normal file
40
include/triton/Dialect/NVGPU/IR/NVGPUDialect.td
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef NVGPU_DIALECT
|
||||
#define NVGPU_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def NVGPU_Dialect : Dialect {
|
||||
let name = "nvgpu";
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
|
||||
let description = [{
|
||||
NVGPU Dialect.
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"mlir::LLVM::LLVMDialect"
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
380
include/triton/Dialect/NVGPU/IR/NVGPUOps.td
Normal file
380
include/triton/Dialect/NVGPU/IR/NVGPUOps.td
Normal file
@@ -0,0 +1,380 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef NVGPU_OPS
|
||||
#define NVGPU_OPS
|
||||
|
||||
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
|
||||
include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
|
||||
def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
|
||||
def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
|
||||
def I64Ptr_shared : LLVM_IntPtrBase<64, 3>;
|
||||
|
||||
class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
|
||||
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> {
|
||||
string llvmBuilder = [{
|
||||
createWGMMAFence(builder);
|
||||
}];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
|
||||
def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
string llvmBuilder = [{
|
||||
createWGMMACommitGroup(builder);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_WGMMAWaitOp : NVGPU_Op<"wgmma_wait_group", []> {
|
||||
let arguments = (ins I32Attr:$pendings);
|
||||
let assemblyFormat = "attr-dict";
|
||||
string llvmBuilder = [{
|
||||
createWGMMAWaitGroup(builder, $pendings);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count);
|
||||
let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)";
|
||||
string llvmBuilder = [{
|
||||
auto *i32Ty = builder.getInt32Ty();
|
||||
auto *arriveCnt = builder.getInt32($count);
|
||||
createExternalCall(builder, "__nv_mbarrier_init", {builder.CreatePtrToInt($mbarrier, i32Ty),
|
||||
arriveCnt,
|
||||
builder.CreateIntCast($pred, i32Ty, false)});
|
||||
}];
|
||||
}
|
||||
|
||||
def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType",
|
||||
"mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'",
|
||||
[
|
||||
I32EnumAttrCase<"normal", 0>,
|
||||
I32EnumAttrCase<"cp_async", 1>,
|
||||
I32EnumAttrCase<"expect_tx", 2>,
|
||||
I32EnumAttrCase<"remote", 3>,
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> {
|
||||
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional<I32>:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr<I32Attr, "0">:$txCount);
|
||||
let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)";
|
||||
string llvmBuilder = [{
|
||||
auto *i32Ty = builder.getInt32Ty();
|
||||
createMBarrierArrive(builder, $arriveType, builder.CreatePtrToInt($mbarrier, i32Ty),
|
||||
builder.CreateIntCast($pred, i32Ty, false), $ctaId,
|
||||
$txCount);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> {
|
||||
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase);
|
||||
let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)";
|
||||
|
||||
string llvmBuilder = [{
|
||||
auto *i32Ty = builder.getInt32Ty();
|
||||
createExternalCall(builder, "__nv_mbarrier_wait", {builder.CreatePtrToInt($mbarrier, i32Ty),
|
||||
builder.CreateIntCast($phase, i32Ty, false)});
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> {
|
||||
let arguments = (ins I32:$bar, I32:$numThreads);
|
||||
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
createExternalCall(builder, "__nv_bar_arrive", {$bar, $numThreads});
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> {
|
||||
let arguments = (ins I32:$bar, I32:$numThreads);
|
||||
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
createExternalCall(builder, "__nv_bar_wait", {$bar, $numThreads});
|
||||
}];
|
||||
}
|
||||
|
||||
def WGMMADesc_ModeAttr : I32EnumAttr<"WGMMADescMode",
|
||||
"wgmma desc mode, either 'none', 'swizzle128', 'swizzle64', or 'swizzle32'",
|
||||
[
|
||||
I32EnumAttrCase<"none", 0>,
|
||||
I32EnumAttrCase<"swizzle128", 1>,
|
||||
I32EnumAttrCase<"swizzle64", 2>,
|
||||
I32EnumAttrCase<"swizzle32", 3>
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> {
|
||||
let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode);
|
||||
let results = (outs I64:$res);
|
||||
let assemblyFormat = "$buffer `,` $height attr-dict `:` functional-type(operands, results)";
|
||||
string llvmBuilder = [{
|
||||
$res = createWGMMADesc(builder, builder.CreatePtrToInt($buffer, builder.getInt32Ty()), $mode, $height);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> {
|
||||
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc,
|
||||
I1:$pred, Variadic<I32>:$coords, Optional<I16>:$mcastMask);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
auto *i32Ty = builder.getInt32Ty();
|
||||
auto *i64Ty = builder.getInt64Ty();
|
||||
createTMALoadTiled(builder,
|
||||
builder.CreatePtrToInt($dst, i32Ty),
|
||||
builder.CreatePtrToInt($mbarrier, i32Ty),
|
||||
builder.CreatePtrToInt($tmaDesc, i64Ty),
|
||||
$l2Desc, $mcastMask, builder.CreateIntCast($pred, builder.getInt32Ty(), false), $coords);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> {
|
||||
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic<I32>:$coords, I16Attr:$mcastMask);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
auto *i32Ty = builder.getInt32Ty();
|
||||
auto *i64Ty = builder.getInt64Ty();
|
||||
createTMALoadIm2col(builder,
|
||||
builder.CreatePtrToInt($dst, i32Ty),
|
||||
builder.CreatePtrToInt($mbarrier, i32Ty),
|
||||
builder.CreatePtrToInt($tmaDesc, i64Ty),
|
||||
$l2Desc, $mcastMask, $im2colOffsets, builder.CreateIntCast($pred, builder.getInt32Ty(), false), $coords);
|
||||
}];
|
||||
}
|
||||
|
||||
def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout",
|
||||
"wgmma layout, either 'row' or 'col'",
|
||||
[
|
||||
I32EnumAttrCase<"row", 0>,
|
||||
I32EnumAttrCase<"col", 1>
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType",
|
||||
"wgmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'",
|
||||
[
|
||||
I32EnumAttrCase<"s8", 0>,
|
||||
I32EnumAttrCase<"s32", 1>,
|
||||
I32EnumAttrCase<"e4m3", 2>,
|
||||
I32EnumAttrCase<"e5m2", 3>,
|
||||
I32EnumAttrCase<"f16", 4>,
|
||||
I32EnumAttrCase<"bf16", 5>,
|
||||
I32EnumAttrCase<"tf32", 6>,
|
||||
I32EnumAttrCase<"f32", 7>
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">;
|
||||
|
||||
def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
|
||||
let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
|
||||
I32Attr:$m, I32Attr:$n, I32Attr:$k,
|
||||
WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
|
||||
WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
|
||||
let results = (outs LLVM_AnyStruct:$res);
|
||||
let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
|
||||
string llvmBuilder = [{
|
||||
$res = createWGMMA(builder, $m, $n, $k, $eltTypeC, $eltTypeA, $eltTypeB, $layoutA, $layoutB, $opA, $opB, $opC);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
string llvmBuilder = [{
|
||||
createExternalCall(builder, "__nv_cga_barrier_sync");
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_CGABarrierArriveOp : NVGPU_Op<"cga_barrier_arrive", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
string llvmBuilder = [{
|
||||
createExternalCall(builder, "__nv_cga_barrier_arrive");
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_CGABarrierWaitOp : NVGPU_Op<"cga_barrier_wait", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
string llvmBuilder = [{
|
||||
createExternalCall(builder, "__nv_cga_barrier_wait");
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> {
|
||||
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Type":$resultTy, "Value":$addr, "Value":$ctaId)>,
|
||||
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth, "unsigned":$vec)>,
|
||||
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth)>
|
||||
];
|
||||
let results = (outs LLVM_LoadableType:$result);
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
|
||||
string llvmBuilder = [{
|
||||
$result = createLoadSharedCluster(builder, $addr, $ctaId, $bitwidth, $vec);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_StoreDSmemOp : NVGPU_Op<"store_dsmem", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId,
|
||||
Variadic<LLVM_LoadableType>:$values, I1:$pred);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "Value":$value, "Value":$pred)>,
|
||||
];
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
createStoreSharedCluster(builder, $addr, $ctaId, $values, $pred, op.getBitwidth(), op.getVec());
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
unsigned getBitwidth();
|
||||
unsigned getVec();
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
|
||||
let arguments = (ins BoolAttr:$bCluster);
|
||||
string llvmBuilder = [{
|
||||
if ($bCluster)
|
||||
createExternalCall(builder, "__nv_fence_async_shared_cluster", {});
|
||||
else
|
||||
createExternalCall(builder, "__nv_fence_async_shared_cta", {});
|
||||
}];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_FenceMBarrierInitOp : NVGPU_Op<"fence_mbarrier_init", []> {
|
||||
string llvmBuilder = [{
|
||||
createExternalCall(builder, "__nv_fence_mbarrier_init", {});
|
||||
}];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> {
|
||||
let arguments = (ins I1Attr:$relaxed);
|
||||
|
||||
string llvmBuilder = [{
|
||||
if ($relaxed)
|
||||
createExternalCall(builder, "__nv_cluster_arrive_relaxed", {});
|
||||
else
|
||||
createExternalCall(builder, "__nv_cluster_arrive", {});
|
||||
}];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
|
||||
string llvmBuilder = [{
|
||||
createExternalCall(builder, "__nv_cluster_wait", {});
|
||||
}];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic<I32>:$coords);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
auto *i32Ty = builder.getInt32Ty();
|
||||
auto *i64Ty = builder.getInt64Ty();
|
||||
createTMAStoreTiled(builder,
|
||||
builder.CreatePtrToInt($tmaDesc, i64Ty),
|
||||
builder.CreatePtrToInt($src, i32Ty),
|
||||
builder.CreateIntCast($pred, i32Ty, false), $coords);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I8Ptr_shared:$addr, Variadic<I32>:$datas);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
auto *i32Ty = builder.getInt32Ty();
|
||||
createStoreMatrix(builder,
|
||||
builder.CreatePtrToInt($addr, i32Ty),
|
||||
$datas);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_OffsetOfStmatrixV4Op : NVGPU_Op<"offset_of_stmatrix_v4", []> {
|
||||
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
|
||||
let results = (outs I32:$offset);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
|
||||
string llvmBuilder = [{
|
||||
$offset = createOffsetOfStmatrixV4(builder, $threadId, $rowOfWarp, $elemIdx, $leadingDimOffset, $rowStride, $swizzleEnabled);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_OffsetOfSts64Op : NVGPU_Op<"offset_of_sts64", []> {
|
||||
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
|
||||
let results = (outs I32:$offset);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
|
||||
string llvmBuilder = [{
|
||||
$offset = createOffsetOfSts64(builder, $threadId, $rowOfWarp, $elemIdx, $leadingDimOffset, $rowStride, $swizzleEnabled);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_Sts64Op : NVGPU_Op<"sts64", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I32:$offset, AnyTypeOf<[F32, I32]>:$d0, AnyTypeOf<[F32, I32]>:$d1);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
createSts64(builder, $offset, $d0, $d1);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_CvtPackOp : NVGPU_Op<"cvt_pack", []> {
|
||||
let arguments = (ins AnyTypeOf<[F16, I16]>:$d0, AnyTypeOf<[F16, I16]>:$d1);
|
||||
let results = (outs I32:$result);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
||||
string llvmBuilder = [{
|
||||
$result = createCvtPack(builder, $d0, $d1);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
|
||||
let results = (outs I32:$result);
|
||||
let assemblyFormat = "attr-dict";
|
||||
string llvmBuilder = [{
|
||||
$result = createClusterId(builder);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> {
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
createRegAlloc(builder, $regCount);
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> {
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
string llvmBuilder = [{
|
||||
createRegDealloc(builder, $regCount);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
41
include/triton/Dialect/NVGPU/ToLLVMIR/NVGPUToLLVMIR.h
Normal file
41
include/triton/Dialect/NVGPU/ToLLVMIR/NVGPUToLLVMIR.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_NVGPU_NVGPUTOLLVMIRTRANSLATION_H
|
||||
#define TRITON_DIALECT_NVGPU_NVGPUTOLLVMIRTRANSLATION_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
class MLIRContext;
|
||||
|
||||
/// Register the nvgpu dialect and the translation from it to the LLVM IR in the
|
||||
/// given registry;
|
||||
void registerNVGPUDialectTranslation(DialectRegistry ®istry);
|
||||
|
||||
/// Register the nvgpu dialect and the translation from it in the registry
|
||||
/// associated with the given context.
|
||||
void registerNVGPUDialectTranslation(MLIRContext &context);
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_NVGPU_NVGPUTOLLVMIRTRANSLATION_H
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
||||
|
||||
@@ -9,6 +9,8 @@ include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
|
||||
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
|
||||
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
|
||||
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
|
||||
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
@@ -563,6 +565,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
|
||||
|
||||
let results = (outs TT_TensorPtr:$result);
|
||||
|
||||
// TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly
|
||||
// Add additional `[]` to increase readability and split variadic lists
|
||||
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)";
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
|
||||
// Scalar Pointer Type: `ptr<>`
|
||||
def TT_Ptr : TT_PtrOf<[AnyType]>;
|
||||
|
||||
// Tensor of Pointer Type
|
||||
// Tensor of Pointer Type: `tensor<ptr<>>`
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
|
||||
// Tensor of Pointer Type or Pointer type: `tensor<ptr<>>` or `ptr<>`
|
||||
|
||||
@@ -14,6 +14,8 @@ namespace triton {
|
||||
|
||||
bool isTensorPointerType(Type type);
|
||||
|
||||
bool isTensorOrTensorPointerType(Type type);
|
||||
|
||||
unsigned getPointeeBitWidth(Type type);
|
||||
|
||||
Type getPointeeType(Type type);
|
||||
|
||||
@@ -9,7 +9,6 @@ namespace triton {
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
std::unique_ptr<Pass> createReorderBroadcastPass();
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createRewriteTensorPointerPass(int computeCapability = 80);
|
||||
|
||||
|
||||
@@ -3,9 +3,13 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu)
|
||||
add_public_tablegen_target(TritonGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
|
||||
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(TritonGPUAttrDefsIncGen)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonGPU depends on Triton
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||
@@ -71,17 +72,41 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned>
|
||||
getShapePerCTA(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
|
||||
|
||||
SmallVector<unsigned> getOrder(Attribute layout);
|
||||
|
||||
CTALayoutAttr getCTALayout(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTASplitNum(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTAOrder(Attribute layout);
|
||||
|
||||
/* The difference between ShapePerCTATile and ShapePerCTA:
|
||||
* (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp *
|
||||
* WarpsPerCTA in each dimension and is independent from the tensor shape.
|
||||
* (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension.
|
||||
* (3) In the implementation of emitIndices, ShapePerCTATile will
|
||||
* be replicated or wraped to fit ShapePerCTA.
|
||||
*/
|
||||
SmallVector<unsigned>
|
||||
getShapePerCTATile(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
|
||||
|
||||
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
|
||||
ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> getShapePerCTA(Type type);
|
||||
|
||||
unsigned getNumWarpsPerCTA(Attribute layout);
|
||||
|
||||
unsigned getNumCTAs(Attribute layout);
|
||||
|
||||
bool isaDistributedLayout(Attribute layout);
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
bool isExpensiveCat(CatOp cat, Attribute &targetEncoding);
|
||||
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
@@ -41,6 +41,19 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CTA Layout
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CTALayoutAttr : TritonGPU_Attr<"CTALayout"> {
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$CTAsPerCGA,
|
||||
ArrayRefParameter<"unsigned">:$CTASplitNum,
|
||||
ArrayRefParameter<"unsigned">:$CTAOrder
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -64,26 +77,49 @@ are stored contiguously
|
||||
_ _ _ _ /\_ _ _ _
|
||||
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
|
||||
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case,
|
||||
when the matrix is stored in shared memory, there will be an offset not
|
||||
only in the stride dimension, but also in the leading dimension. For example,
|
||||
a matrix of size 16x128 and data type I8 is stored in the shared memory with
|
||||
64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64,
|
||||
compared to 1*64 when the hasLeadingOffset is false.
|
||||
}];
|
||||
|
||||
// swizzle info: vec, perPhase, maxPhase
|
||||
// order: the fastest-changing axis first
|
||||
let parameters = (
|
||||
ins
|
||||
// swizzle info
|
||||
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
|
||||
"unsigned":$vec,
|
||||
"unsigned":$perPhase,
|
||||
"unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned">:$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"bool":$hasLeadingOffset
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "unsigned":$vec,
|
||||
"unsigned":$perPhase,
|
||||
"unsigned":$maxPhase,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout), [{
|
||||
bool hasLeadingOffset = false; // default value
|
||||
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"unsigned":$typeWidthInBit), [{
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
// number of rows per phase
|
||||
|
||||
@@ -92,34 +128,34 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
// ---- begin Volta ----
|
||||
if (mmaEnc.isVolta()) {
|
||||
int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8));
|
||||
int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
bool is_row = order[0] != 0;
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
|
||||
is_row && (shape[order[0]] <= 16);
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) :
|
||||
is_row && (shapePerCTA[order[0]] <= 16);
|
||||
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
|
||||
((is_row && !is_vec4) ? 2 : 1);
|
||||
int rep = 2 * pack_size;
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
int vec = 2 * rep;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// ---- begin Ampere ----
|
||||
if (mmaEnc.isAmpere()) {
|
||||
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getMMAv2kWidth());
|
||||
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getMMAv2kWidth());
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getMMAv2kWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if ((32 / typeWidthInBit != dotOpEnc.getMMAv2kWidth()) && order[0] == inner)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
@@ -127,12 +163,19 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
|
||||
// ---- begin version 3 ----
|
||||
if (mmaEnc.isHopper()) {
|
||||
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
|
||||
" is Hopper has not been implemented yet");
|
||||
return $_get(context, 1, 1, 1, order, CTALayout, true);
|
||||
}
|
||||
|
||||
// ---- not implemented ----
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
}]>,
|
||||
@@ -140,9 +183,38 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy), [{
|
||||
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
|
||||
return get(context, dotOpEnc, shape, order, bitwidth);
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy), [{
|
||||
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
|
||||
int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1;
|
||||
|
||||
// get proper shared memory swizzling mode from the contiguous dimension
|
||||
// size of the origin blocked layout.
|
||||
auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
|
||||
if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
|
||||
perPhase = 1;
|
||||
maxPhase = 8;
|
||||
} else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
|
||||
perPhase = 2;
|
||||
maxPhase = 4;
|
||||
} else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
|
||||
perPhase = 4;
|
||||
maxPhase = 2;
|
||||
} else {
|
||||
llvm_unreachable("unsupported shared memory layout for MMAv3");
|
||||
}
|
||||
|
||||
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true);
|
||||
}]>
|
||||
];
|
||||
|
||||
@@ -194,7 +266,7 @@ used to promote memory coalescing in LoadInst and StoreInst.
|
||||
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
|
||||
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
|
||||
|
||||
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
|
||||
Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows:
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
@@ -210,82 +282,143 @@ for
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {1, 1}
|
||||
}>
|
||||
|
||||
Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows:
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
for
|
||||
|
||||
#triton_gpu.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {1, 1}
|
||||
}>
|
||||
|
||||
Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and
|
||||
4 CTAs (taking 2x2 for example) as follows:
|
||||
|
||||
CTA [0,0] CTA [0,1]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
|
||||
CTA [1,0] CTA [1,1]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
for
|
||||
|
||||
#triton_gpu.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {2, 2}
|
||||
}>
|
||||
}];
|
||||
|
||||
|
||||
let builders = [
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// TODO: compiles on MacOS but not linux?
|
||||
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
|
||||
// "ArrayRef<unsigned>":$threadsPerWarp,
|
||||
// "ArrayRef<unsigned>":$warpsPerCTA,
|
||||
// "ArrayRef<unsigned>":$order), [{
|
||||
// int rank = threadsPerWarp.size();
|
||||
// SmallVector<unsigned, 4> sizePerWarp(rank);
|
||||
// SmallVector<unsigned, 4> sizePerCTA(rank);
|
||||
// for (unsigned i = 0; i < rank; i++) {
|
||||
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
|
||||
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
|
||||
// }
|
||||
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
|
||||
// }]>,
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// Default builder takes sizePerThread, order and numWarps, and tries to
|
||||
// pack numWarps*32 threads in the provided order for use in a type
|
||||
// of the given shape.
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$threadsPerWarp), [{
|
||||
int rank = sizePerThread.size();
|
||||
unsigned remainingLanes = threadsPerWarp;
|
||||
unsigned remainingThreads = numWarps*threadsPerWarp;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
SmallVector<unsigned, 4> rankedThreadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
for (int _dim = 0; _dim < rank - 1; ++_dim) {
|
||||
int i = order[_dim];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
|
||||
rankedThreadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / rankedThreadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= rankedThreadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= rankedThreadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
rankedThreadsPerWarp[order[rank-1]] = threadsPerWarp / prevLanes;
|
||||
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, rankedThreadsPerWarp, warpsPerCTA, order);
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$sizePerThread,
|
||||
ArrayRefParameter<"unsigned">:$threadsPerWarp,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
// fastest-changing axis first
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
"order of axes by the rate of changing"
|
||||
>:$order
|
||||
// These attributes can be inferred from the rest
|
||||
// ArrayRefParameter<"unsigned">:$sizePerWarp,
|
||||
// ArrayRefParameter<"unsigned">:$sizePerCTA
|
||||
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
|
||||
"CTALayoutAttr":$CTALayout
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$numThreadsPerWarp,
|
||||
"CTALayoutAttr":$CTALayout), [{
|
||||
unsigned rank = sizePerThread.size();
|
||||
SmallVector<unsigned, 4> threadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
SmallVector<int64_t> shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
unsigned remainingLanes = numThreadsPerWarp;
|
||||
unsigned remainingThreads = numWarps * numThreadsPerWarp;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
|
||||
// starting from the contiguous dimension
|
||||
for (unsigned d = 0; d < rank - 1; ++d) {
|
||||
unsigned i = order[d];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]);
|
||||
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= threadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= threadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes;
|
||||
warpsPerCTA[order[rank - 1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$numThreadsPerWarp,
|
||||
"unsigned":$numCTAs), [{
|
||||
unsigned rank = sizePerThread.size();
|
||||
SmallVector<unsigned, 4> CTAsPerCGA(rank);
|
||||
SmallVector<unsigned, 4> CTASplitNum(rank);
|
||||
ArrayRef<unsigned> CTAOrder = order;
|
||||
|
||||
unsigned remainingCTAs = numCTAs;
|
||||
|
||||
// starting from the most strided dimension
|
||||
for (int d = rank - 1; d >= 0; --d) {
|
||||
unsigned i = order[d];
|
||||
CTAsPerCGA[i] = std::clamp<unsigned>(remainingCTAs, 1, shape[i] / sizePerThread[i]);
|
||||
CTASplitNum[i] = CTAsPerCGA[i];
|
||||
remainingCTAs /= CTAsPerCGA[i];
|
||||
}
|
||||
|
||||
CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level
|
||||
|
||||
CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
|
||||
return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
@@ -381,13 +514,17 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
ins
|
||||
"unsigned":$versionMajor,
|
||||
"unsigned":$versionMinor,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
ArrayRefParameter<"unsigned">:$instrShape
|
||||
);
|
||||
|
||||
let builders = [
|
||||
// Specially for MMAV1(Volta)
|
||||
AttrBuilder<(ins "int":$versionMajor,
|
||||
"int":$numWarps,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"ArrayRef<unsigned>":$instrShape,
|
||||
"ArrayRef<int64_t>":$shapeC,
|
||||
"bool":$isARow,
|
||||
"bool":$isBRow,
|
||||
@@ -401,7 +538,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
(isAVec4 * (1<<2)) |\
|
||||
(isBVec4 * (1<<3));
|
||||
|
||||
|
||||
// TODO: Share code with
|
||||
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
|
||||
// rep,spw and fpw.
|
||||
@@ -426,12 +562,14 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shapeC[1] / spw[1]);
|
||||
} while (wpt_nm1 != wpt);
|
||||
|
||||
return $_get(context, versionMajor, versionMinor, wpt);
|
||||
return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape);
|
||||
}]>,
|
||||
|
||||
|
||||
AttrBuilder<(ins "int":$versionMajor,
|
||||
"int":$numWarps,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"ArrayRef<unsigned>":$instrShape,
|
||||
"ArrayRef<int64_t>":$shapeA,
|
||||
"ArrayRef<int64_t>":$shapeB,
|
||||
"ArrayRef<int64_t>":$shapeC,
|
||||
@@ -441,15 +579,20 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
return get(context, versionMajor, numWarps, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
|
||||
return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
bool isVolta() const;
|
||||
bool isAmpere() const;
|
||||
bool isHopper() const;
|
||||
|
||||
unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;
|
||||
|
||||
// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
|
||||
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;
|
||||
|
||||
// Number of bits in versionMinor to hold the ID of the MMA encoding instance.
|
||||
// Here 5 bits can hold 32 IDs in a single module.
|
||||
static constexpr int numBitsToHoldMmaV1ID{5};
|
||||
@@ -544,6 +687,4 @@ section 9.7.13.4.1 for more details.
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -16,6 +16,7 @@ def TritonGPU_Dialect : Dialect {
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"mlir::triton::nvgpu::NVGPUDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
@@ -23,14 +24,27 @@ def TritonGPU_Dialect : Dialect {
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
Attribute numWarps = mod->getDiscardableAttr("triton_gpu.num-warps");
|
||||
if(!numWarps)
|
||||
if(!mod->hasAttr("triton_gpu.num-warps"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return numWarps.cast<IntegerAttr>().getInt();
|
||||
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getNumCTAs(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-ctas"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
|
||||
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getComputeCapability(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.compute-capability"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
|
||||
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
|
||||
}
|
||||
void registerTypes();
|
||||
|
||||
static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
|
||||
|
||||
static int getThreadsPerWarp(ModuleOp mod) {
|
||||
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
|
||||
if(!threadsPerWarp) {
|
||||
@@ -38,7 +52,6 @@ def TritonGPU_Dialect : Dialect {
|
||||
}
|
||||
return threadsPerWarp.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
}];
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define TRITONGPU_OPS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
@@ -46,6 +47,20 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncBulkWaitOp : TTG_Op<"async_bulk_wait"> {
|
||||
let summary = "async bulk wait";
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 90;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
|
||||
let summary = "async commit group";
|
||||
|
||||
@@ -58,6 +73,18 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> {
|
||||
let summary = "async bulk commit group";
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 90;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
||||
// This is needed because these ops don't
|
||||
@@ -106,6 +133,98 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
// TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified
|
||||
def TTG_InsertSliceOp : TTG_Op<"insert_slice",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncoding,
|
||||
MemoryEffects<[MemRead, MemWrite]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from src type",
|
||||
"src", "other", "getPointeeType($_self)",
|
||||
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
|
||||
let summary = "insert slice";
|
||||
|
||||
let description = [{
|
||||
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s
|
||||
`$index` argument and `$axis` attribute.
|
||||
|
||||
It returns a copy of `$dst` with the proper slice updated with the value of `$src`.
|
||||
|
||||
When converting from `tt.load` to `triton_gpu.insert_slice`, the `$evict`, `$cache`, and `$isVolatile` fields
|
||||
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
|
||||
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
|
||||
|
||||
The insert_slice operation supports the following arguments:
|
||||
|
||||
* src: the tensor that is inserted.
|
||||
* dst: the tensor into which the `$src` tensor is inserted.
|
||||
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
|
||||
* mask: optional tensor-rank number of boolean masks which specify which
|
||||
elements of the `$src` tensor are inserted into the `$dst` tensor.
|
||||
* other: optional tensor-rank number of other tensors which specify what
|
||||
values are inserted into the `$dst` tensor if the corresponding
|
||||
element of the `$mask` tensor is false.
|
||||
|
||||
ttgpu.load_tile_async depracate
|
||||
triton_gpu.insert_slice might be further lowered into triton_gpu_async for different hardware implementations
|
||||
|
||||
like tt.load, ttgpu.insert_slice/insert_slice_async has two modes up to the type of src
|
||||
mode 1: ptr/src is a tensor of pointers
|
||||
mode 2: ptr/src is a tensor pointer
|
||||
|
||||
Some typical lowering paths are:
|
||||
in case the load is pipelined by the pipeline pass( load is inside kBlock loop, which means "pipeline pass):
|
||||
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1)
|
||||
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1) -(MaterializeLoad)> ttgpu.insert_slice_async(mode 1) + ttgpu.await-> llvm
|
||||
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -(Pipeline)-> ttgpu.insert_slice(mode 2) -(MaterializeLoad)> ttgpu.insert_slice_async_v2(mode 2) + ttgpu.await-> llvm
|
||||
|
||||
otherwise:
|
||||
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1)
|
||||
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 1) + ttgpu.await -> llvm
|
||||
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 2) + ttgpu.await -> llvm
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
|
||||
%2 = triton_gpu.insert_slice %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"Value":$mask, "Value":$other,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
|
||||
DenseSet<unsigned> validLoadBytes;
|
||||
if (computeCapability >= 80) {
|
||||
validLoadBytes = {4, 8, 16};
|
||||
}
|
||||
return validLoadBytes;
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
|
||||
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
|
||||
@@ -173,7 +292,8 @@ def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
|
||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncoding,
|
||||
MemoryEffects<[MemRead]>,
|
||||
// TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
|
||||
MemoryEffects<[MemRead, MemWrite]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
@@ -219,7 +339,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
|
||||
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
26
include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
Normal file
26
include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
Normal file
@@ -0,0 +1,26 @@
|
||||
#ifndef TRITONGPU_TYPES
|
||||
#define TRITONGPU_TYPES
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
|
||||
class TTG_TypeDef<string name, string _mnemonic, list<Trait> traits = []>
|
||||
: TypeDef<TritonGPU_Dialect, name, traits> {
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
def TTG_TokenType : TTG_TypeDef<"Token", "token"> {
|
||||
let parameters = (ins "int32_t":$type);
|
||||
|
||||
let builders = [
|
||||
TypeBuilder<(ins "unsigned":$type), [{
|
||||
return $_get($_ctxt, type);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
10
include/triton/Dialect/TritonGPU/IR/Types.h
Normal file
10
include/triton/Dialect/TritonGPU/IR/Types.h
Normal file
@@ -0,0 +1,10 @@
|
||||
#ifndef TRITONGPU_IR_TYPES_H_
|
||||
#define TRITONGPU_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Types.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
||||
@@ -2,9 +2,14 @@
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
|
||||
int numWarps = 4,
|
||||
int numCTAs = 1,
|
||||
int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
|
||||
@@ -25,6 +30,8 @@ std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUOptimizeDotOperandsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUOptimizeEpiloguePass();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -14,13 +14,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numStages", "num-stages",
|
||||
"int32_t", /*default*/"2",
|
||||
"number of pipeline stages">
|
||||
"int32_t", /*default*/"3",
|
||||
"number of pipeline stages">,
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps per block">,
|
||||
Option<"numCTAs", "num-ctas",
|
||||
"int32_t", /*default*/"1",
|
||||
"number of CTAs per CGA">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
@@ -50,6 +60,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
|
||||
let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
@@ -70,6 +81,7 @@ def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir
|
||||
let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
@@ -96,6 +108,20 @@ def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
}
|
||||
|
||||
def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::ModuleOp"> {
|
||||
let summary = "Optimize epilogue: (1) Store accumulators directly without going thorough SMEM in epilogue.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUOptimizeEpiloguePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
}
|
||||
|
||||
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {
|
||||
|
||||
@@ -13,15 +13,17 @@ namespace mlir {
|
||||
|
||||
class TritonGPUTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps,
|
||||
int threadsPerWarp);
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp,
|
||||
int numCTAs);
|
||||
int getNumWarps() const { return numWarps; }
|
||||
int getThreadsPerWarp() const { return threadsPerWarp; }
|
||||
int getNumCTAs() const { return numCTAs; }
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
int numWarps;
|
||||
int threadsPerWarp;
|
||||
int numCTAs;
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
|
||||
@@ -10,8 +10,99 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace triton {
|
||||
class LoadOp;
|
||||
class StoreOp;
|
||||
class FuncOp;
|
||||
namespace gpu {
|
||||
class SharedEncodingAttr;
|
||||
}
|
||||
} // namespace triton
|
||||
|
||||
LogicalResult fixupLoops(ModuleOp mod);
|
||||
|
||||
SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
|
||||
const ArrayRef<int64_t> &shape,
|
||||
RankedTensorType type);
|
||||
|
||||
/// Returns true if the Load is for TMA
|
||||
bool isLoadFromTensorPtr(triton::LoadOp op);
|
||||
|
||||
/// Returns true if the store is for TMA
|
||||
bool isStoreToTensorPtr(triton::StoreOp op);
|
||||
|
||||
/// Return the first consumer of v
|
||||
Operation *getFirstUser(Value v);
|
||||
|
||||
/// Return the proper SharedEncodingAttr according to shape/order
|
||||
triton::gpu::SharedEncodingAttr getSharedEncoding(RankedTensorType tensorTy);
|
||||
|
||||
/* Dump Triton IR in graphviz dot format.
|
||||
*
|
||||
* You can override `onValue` and `onOperation` in a subclass to mark
|
||||
* specific Values and Operations. The below subclass
|
||||
* GraphLayoutMarker is an example.
|
||||
*
|
||||
* Default NodeInfo for Value nodes:
|
||||
* {{"shape": "box"},
|
||||
* {"style", "filled"},
|
||||
* {"fillcolor", "white"},
|
||||
* {"label", shapeStr}}
|
||||
*
|
||||
* Default NodeInfo for Operation nodes:
|
||||
* {{"shape": "ellipse"},
|
||||
* {"style", "filled"},
|
||||
* {"fillcolor", "white"},
|
||||
* {"label", operationName}}
|
||||
*
|
||||
* If the key "label" is not set by `onValue` or `onOperation`, default labels
|
||||
* will be generated. For Value node, the default label is the shape string and
|
||||
* for Operation node, it is the operation name.
|
||||
*
|
||||
* Reference:
|
||||
* https://graphviz.org/doc/info/shapes.html
|
||||
* https://graphviz.org/doc/info/colors.html
|
||||
*
|
||||
* Usage:
|
||||
* C++: GraphDumper().dumpToFile(func, "func.dot");
|
||||
* Shell: dot -Tjpg func.dot -o func.jpg
|
||||
*/
|
||||
class GraphDumper {
|
||||
public:
|
||||
using NodeInfo = std::map<std::string, std::string>;
|
||||
|
||||
// Override this function to mark specific Values
|
||||
virtual NodeInfo onValue(Value value) const;
|
||||
// Override this function to mark specific Operations
|
||||
virtual NodeInfo onOperation(Operation *op) const;
|
||||
|
||||
std::string dump(triton::FuncOp func) const;
|
||||
void dumpToFile(triton::FuncOp func, const std::string &filename) const;
|
||||
|
||||
protected:
|
||||
std::string getShapeStr(const Type &type) const;
|
||||
|
||||
std::string getUniqueId(Value value) const;
|
||||
std::string getUniqueId(Operation *op) const;
|
||||
|
||||
std::string emitNode(const std::string &id, const NodeInfo style) const;
|
||||
std::string emitEdge(const std::string &srcId,
|
||||
const std::string &destId) const;
|
||||
|
||||
std::string emitValueNode(Value value) const;
|
||||
std::string emitOperationNode(Operation *op) const;
|
||||
};
|
||||
|
||||
/* A subclass of GraphDumper that marks different layout kinds in different
|
||||
* colors.*/
|
||||
class GraphLayoutMarker : public GraphDumper {
|
||||
public:
|
||||
NodeInfo onValue(Value value) const override;
|
||||
|
||||
protected:
|
||||
std::string getColor(const Type &type) const;
|
||||
};
|
||||
|
||||
// TODO: Interface
|
||||
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
Attribute &ret);
|
||||
@@ -38,6 +129,23 @@ void rematerializeConversionChain(
|
||||
LogicalResult canMoveOutOfLoop(BlockArgument arg,
|
||||
SmallVector<Operation *> &cvts);
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order);
|
||||
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, unsigned linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
|
||||
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
|
||||
2
include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt
Normal file
2
include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
15
include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt
Normal file
15
include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu)
|
||||
add_public_tablegen_target(TritonNvidiaGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td)
|
||||
mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen)
|
||||
46
include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
Normal file
46
include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
Normal file
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonNvidiaGPU depends on Triton
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Traits.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
|
||||
53
include/triton/Dialect/TritonNvidiaGPU/IR/Traits.h
Normal file
53
include/triton/Dialect/TritonNvidiaGPU/IR/Traits.h
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_NVIDIA_GPU_IR_TRAITS_H_
|
||||
#define TRITON_NVIDIA_GPU_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifySource1IsSharedEncoding(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <typename ConcreteType>
|
||||
class Source1IsSharedEncoding
|
||||
: public TraitBase<ConcreteType, Source1IsSharedEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySource1IsSharedEncoding(op);
|
||||
}
|
||||
};
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_ATTRDEFS
|
||||
#define TRITONNVIDIAGPU_ATTRDEFS
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_DIALECT
|
||||
#define TRITONNVIDIAGPU_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TritonNvidiaGPU_Dialect : Dialect {
|
||||
let name = "triton_nvidia_gpu";
|
||||
|
||||
let cppNamespace = "::mlir::triton::nvidia_gpu";
|
||||
|
||||
let hasOperationAttrVerify = 1;
|
||||
|
||||
let description = [{
|
||||
Triton Nvidia GPU Dialect.
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvgpu::NVGPUDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-warps"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getNumCTAs(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-ctas"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
|
||||
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getComputeCapability(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.compute-capability"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
|
||||
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
|
||||
}
|
||||
void registerTypes();
|
||||
|
||||
// Warp specialization related:
|
||||
static std::string getWSSupportedAttrName() { return "triton_gpu.enable-warp-specialization"; }
|
||||
static int getWSSupportedAttr(ModuleOp mod) {
|
||||
auto name = getWSSupportedAttrName();
|
||||
if (!mod->hasAttr(name)) return 0;
|
||||
return mod->getAttrOfType<IntegerAttr>(name).getInt();
|
||||
}
|
||||
}];
|
||||
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
|
||||
|
||||
#endif
|
||||
386
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Normal file
386
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Normal file
@@ -0,0 +1,386 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_OPS
|
||||
#define TRITONNVIDIAGPU_OPS
|
||||
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
include "mlir/Interfaces/DestinationStyleOpInterface.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
|
||||
// Trait: the op's second source operand is expected to carry a shared
// encoding (verified by the corresponding NativeOpTrait in C++).
def Source1IsSharedEncoding: NativeOpTrait<"Source1IsSharedEncoding">;

// Trait: all results of the op are expected to carry a shared encoding.
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;

// Base class for every op in the TritonNvidiaGPU (TTNG) dialect.
class TTNG_Op<string mnemonic, list<Trait> traits = []> :
    Op<TritonNvidiaGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
// ---------------------------------------------------------------------------
// MBarrier related ops.
//
// 1. The following mbarrier commands are currently not needed and are not
//    modeled here:
//      (1) mbarrier.expect_tx
//      (2) mbarrier.arrive_drop
//      (3) mbarrier.complete_tx
//      (4) mbarrier.inval
//
// 2. MBarriers may be created as a vector and accessed separately via
//    tensor.extract. All mbarriers created in one vector are initialized
//    with the same counter configuration. A typical usage pattern:
//
//      %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4x!tt.ptr<i64>>
//      scf.for %iv = %lb to %ub step %step iter_args() -> () {
//        %buffer_id = arith.remi %iv, %c4 : i32
//        %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
//        triton_nvidia_gpu.mbarrier_arrive %2 {expectTx = 2048} : !tt.ptr<i64> -> ()
//      }
//      ...
//      scf.for %iv = %lb to %ub step %step iter_args() -> () {
//        %buffer_id = arith.remi %iv, %c4 : i32
//        %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
//        triton_nvidia_gpu.mbarrier_wait %2, %c0 : !tt.ptr<i64>, i1 -> ()
//      }
// ---------------------------------------------------------------------------

def TTNG_AllocMBarrierOp : TTNG_Op<"alloc_mbarrier", [MemoryEffects<[MemAlloc]>]> {
  let summary = "allocate a vector of mbarriers";

  let description = [{
    Allocate and initialize a vector of mbarriers. The size of the vector is implied in the returned type.
    Each mbarrier is initialized as:
    1, the current phase initialized to 0.
    2, the expected arrival count initialized to 'count'.
    3, the pending arrival count initialized to 'count'.
    4, the tx-count initialized to 0.

    Example:

    case a. when created in vector:
    %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4xi64>

    case b. when created in scalar:
    %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : !tt.ptr<i64>

  }];

  // Number of expected arrivals used to initialize every mbarrier.
  let arguments = (ins I32Attr:$count);

  // Scalar pointer or i64 tensor, depending on scalar vs. vector creation.
  let results = (outs AnyTypeOf<[TT_Ptr, I64Tensor]>:$result);

  let assemblyFormat = [{attr-dict `:` type($result)}];
}

def TTNG_ExtractMBarrierOp : TTNG_Op<"extract_mbarrier", [Pure]> {
  let summary = "extract a mbarrier from a vector of mbarriers";

  let description = [{
    Extract a mbarrier from a vector of mbarriers

    Example:

    %1 = triton_nvidia_gpu.extract_mbarrier %mbarriers[%idx] : tensor<4xi64>, index -> !tt.ptr<i64>

  }];

  let arguments = (ins I64Tensor:$tensor, I32:$index);

  let results = (outs TT_Ptr:$result);

  let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor) `,` type($index) `->` type($result)";
}

def TTNG_MBarrierWaitOp : TTNG_Op<"mbarrier_wait", [MemoryEffects<[MemRead, MemWrite]>]> {
  let summary = "mbarrier wait";

  let description = [{
    This operation defining the waiting action for a mbarrier.
    The subsequent operations should not execute until this operation completes waiting.

    Example:

    triton_nvidia_gpu.mbarrier_wait %0, %1 : !tt.ptr<i64>

  }];

  // $phase selects which mbarrier phase is being waited on.
  let arguments = (ins TT_Ptr:$mbarrier, I1: $phase);

  let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type($mbarrier)";
}

def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments,
                                                        MemoryEffects<[MemWrite]>]> {
  let summary = "mbarrier arrive";

  let description = [{
    This operation defining the arriving action for a mbarrier.
    txCount:
      An optional attribute that set tx-count. This Op will be lowered into
      mbarrier.arrive.expect_tx if the optional attribute exist.
    trackAsyncOp:
      If true, this op will be lowered into cp.async.mbarrier.arrive.noinc.
    pred:
      Only perform arrive action when pred is true.
    remoteCtaId:
      if set, perform an remote arrive action.

    Example:

    triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr<i64>

  }];

  let arguments = (ins TT_Ptr:$mbarrier,
                       Optional<I1>:$pred,
                       Optional<I32>:$remoteCtaId,
                       I1Attr: $trackAsyncOp,
                       DefaultValuedAttr<I32Attr, "0">: $txCount
                  );

  let assemblyFormat = "operands attr-dict `:` type(operands)";
}
|
||||
|
||||
def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
  let summary = "fence proxy async";

  // When bCluster is true the fence applies at cluster scope.
  let arguments = (ins BoolAttr:$bCluster);

  let assemblyFormat = "attr-dict";

  let extraClassDeclaration = [{
    // The async proxy fence requires sm_90 (Hopper) or newer.
    static bool isSupported(int computeCapability) {
      return computeCapability >= 90;
    }
  }];
}

// TODO[goostavz]: ThreadId & ClusterCTAId should not be exposed to
// ttgpu level. Remove them when async dialect is ready.
def TTNG_GetThreadIdOp : TTNG_Op<"get_thread_id", [Pure]> {
  let description = [{
    Returns the one dimensional threadId.
  }];

  let results = (outs I32:$result);
  let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_GetClusterCTAIdOp : TTNG_Op<"get_cluster_cta_id", [Pure]> {
  let description = [{
    Returns the one dimensional cluster_cta_id.
  }];

  let results = (outs I32:$result);
  let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> {
  let summary = "named barrier arrive";

  // $bar identifies the named barrier; $numThreads is the arrival count.
  let arguments = (ins I32:$bar, I32: $numThreads);

  let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}

def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> {
  let summary = "named barrier wait";

  let arguments = (ins I32:$bar, I32: $numThreads);

  let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}
|
||||
|
||||
def TTNG_InsertSliceAsyncV2Op : TTNG_Op<"insert_slice_async_v2",
                                        [AttrSizedOperandSegments,
                                         ResultsAreSharedEncoding,
                                         // TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
                                         MemoryEffects<[MemRead, MemWrite]>]> {

  // Asynchronously copies a slice from $src into $dst at slice $index along
  // $axis, signalling completion through the mbarrier $mbar.
  let arguments = (ins AnyTypeOf<[TT_Ptr, TT_PtrTensor]>:$src, TT_Tensor:$dst,
                       I32:$index, TT_Ptr:$mbar,
                       Optional<AnyTypeOf<[I1Tensor, I1]>>:$mask, Optional<TT_Type>:$other,
                       TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                       BoolAttr:$isVolatile, I32Attr:$axis);

  let results = (outs TT_Tensor:$result);

  let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}

// TODO: the abstraction of barriers in ttgpu level is pending, will revisit later
// def TTNG_AwaitOp : TTNG_Op<"await", []> {
//   let arguments = (ins TTNG_TokenType:$token);
//   let assemblyFormat = "$token attr-dict `:` type($token)";
// }

def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> {
  // relaxed selects the relaxed variant of the cluster arrive barrier.
  let arguments = (ins I1Attr:$relaxed);
  let assemblyFormat = "attr-dict";
}

def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
  let assemblyFormat = "attr-dict";
}
|
||||
|
||||
//
// DotAsync Op
//
def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
                                            DeclareOpInterfaceMethods<InferTypeOpInterface>,
                                            TypesMatchWith<"result's type matches accumulator's type",
                                                           "d", "c", "$_self">]> {
  let summary = "dot async";

  let description = [{
    $d = matrix_multiply($a, $b) + $c
  }];

  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);

  let results = (outs TT_FpIntTensor:$d);

  let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}

def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> {
  let summary = "dot wait";

  let description = [{
    This operation defining the waiting action for a async dot, MMAv3 .e.g.
    The subsequent operations should not execute until this operation completes waiting.
  }];

  // $pendings: number of outstanding async dot groups allowed to remain.
  let arguments = (ins I32Attr:$pendings);

  let assemblyFormat = "attr-dict";
}

def TTNG_StoreAsyncOp : TTNG_Op<"store_async",
                                [Source1IsSharedEncoding,
                                 MemoryEffects<[MemWrite]>]> {
  let summary = "store asynchronous by a tensor pointer";

  let arguments = (ins TT_TensorPtr:$dst, TT_Tensor:$src,
                       DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache);

  let assemblyFormat = "operands attr-dict `:` type(operands)";
}

def TTNG_GetAgentIdOp : TTNG_Op<"get_agent_id", [Pure]> {
  let results = (outs I32:$result);

  let builders = [OpBuilder<(ins)>];

  let assemblyFormat = "attr-dict `:` type($result)";
}
|
||||
|
||||
//
// Token
//

def TTNG_CreateTokenOp : TTNG_Op<"create_token"> {
  // $num: number of tokens materialized in the resulting tensor.
  let arguments = (ins I32Attr:$num);

  let results = (outs TensorOf<[TTNG_TokenType]>:$result);

  let builders = [OpBuilder<(ins "uint32_t":$num)>];

  let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);

  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}

def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);

  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}

def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);

  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}

def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);

  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}

//
// Mutex
//

def TTNG_GetMutexRoleIdOp : TTNG_Op<"get_mutex_role_id"> {
  let arguments = (ins I32Attr:$num);

  let results = (outs I32:$result);

  let builders = [OpBuilder<(ins "uint32_t":$num)>];

  let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_CreateMutexOp : TTNG_Op<"create_mutex"> {
  let results = (outs TTNG_MutexType:$result);

  let builders = [OpBuilder<(ins)>];

  let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_LockOp : TTNG_Op<"lock"> {
  let arguments = (ins TTNG_MutexType:$mutex);

  let assemblyFormat = "$mutex attr-dict `:` type(operands)";
}

def TTNG_UnlockOp : TTNG_Op<"unlock"> {
  let arguments = (ins TTNG_MutexType:$mutex);

  let assemblyFormat = "$mutex attr-dict `:` type(operands)";
}

def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> {
  let summary = "register allocation";

  let arguments = (ins I32Attr: $regCount);

  let assemblyFormat = "$regCount attr-dict `:` type(operands)";
}

def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
  let summary = "register deallocation";

  let arguments = (ins I32Attr: $regCount);

  let assemblyFormat = "$regCount attr-dict `:` type(operands)";
}

#endif
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_TYPES
#define TRITONNVIDIAGPU_TYPES

include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"

// Base class for types declared by the TritonNvidiaGPU dialect.
class TTNG_TypeDef<string name, string _mnemonic>
    : TypeDef<TritonNvidiaGPU_Dialect, name> {
  let mnemonic = _mnemonic;
}

// Opaque token used by the producer/consumer synchronization ops.
def TTNG_TokenType : TTNG_TypeDef<"Token", "token">;

// Opaque mutex used by the lock/unlock ops.
def TTNG_MutexType : TTNG_TypeDef<"Mutex", "mutex">;

#endif
|
||||
33
include/triton/Dialect/TritonNvidiaGPU/IR/Types.h
Normal file
33
include/triton/Dialect/TritonNvidiaGPU/IR/Types.h
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_IR_TYPES_H_
#define TRITONNVIDIAGPU_IR_TYPES_H_

#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"

// Pull in the TableGen-generated type class declarations for the
// TritonNvidiaGPU dialect (Token, Mutex).
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc"

// Fixed: the closing comment previously said TRITON_IR_TYPES_H_, which does
// not match the guard macro defined above.
#endif // TRITONNVIDIAGPU_IR_TYPES_H_
|
||||
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU)
|
||||
add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen)
|
||||
81
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h
Normal file
81
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h
Normal file
@@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace triton {
namespace nvidia_gpu {

/// Cluster (CGA) dimensions planned by the compiler.
/// Used by Triton runtime.
struct ClusterInfo {
  ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {}
  int clusterDimX;
  int clusterDimY;
  int clusterDimZ;
};

} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir

namespace mlir {

// Factory functions for the TritonNvidiaGPU transform passes. Default
// compute capabilities match the pass definitions in Passes.td.

std::unique_ptr<Pass>
createTritonNvidiaGPUMaterializeLoadStorePass(int numWarps = 4,
                                              int computeCapability = 80);

std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
    mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSFeasibilityCheckingPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSDecomposingPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSPipelinePass(int numStages = 3, int numWarps = 4,
                                    int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSMutexPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSMaterializationPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonGPURewriteTensorPointerPass(int computeCapability = 80);

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

} // namespace mlir
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
|
||||
228
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td
Normal file
228
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td
Normal file
@@ -0,0 +1,228 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_PASSES
|
||||
#define TRITONNVIDIAGPU_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def MaterializeLoadStore : Pass<"triton-nvidia-gpu-materialize-load-store", "mlir::ModuleOp"> {
  let summary = "materialize load & store";

  let description = [{
    This pass works after pipeline pass, converting the remaining tt.LoadOp taking
    ptr<tensor> as input into ttg.InsertSliceAsyncOp and emit proper barriers
  }];

  let constructor = "mlir::createTritonNvidiaGPUMaterializeLoadStorePass()";

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
    "mlir::scf::SCFDialect",
    "mlir::arith::ArithDialect"
  ];

  let options = [
    Option<"numWarps", "num-warps",
           "int32_t", /*default*/"4",
           "number of warps per block">,
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}
|
||||
|
||||
def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> {
  let summary = "plan CTA";

  let description = [{
    Plan CTAs in CGA
  }];

  let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()";

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];
}

def TritonGPUWSFeasibilityChecking : Pass<"triton-nvidia-gpu-ws-feasibility-checking", "mlir::ModuleOp"> {
  let summary = "Attach attr named TritonNvidiaGPUDialect::getWSSupportedAttrName() if auto WS supported";

  let description = [{
    Since not every legal triton kernels can be auto WS, this pass does some (conservative) check
    and attaches an attribute named TritonNvidiaGPUDialect::getWSSupportedAttrName() on
    the input module op if the kernel is supported.
  }];

  let constructor = "mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass()";

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];

  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}
|
||||
|
||||
def TritonGPUWSDecomposing : Pass<"triton-nvidia-gpu-ws-decomposing", "mlir::ModuleOp"> {
  let summary = "Clustering on the ops according to their performance hotspots";

  // Fixed: the example IR in the description previously spelled the load op
  // as "tt,load" (comma) instead of "tt.load".
  let description = [{
    Based on compute capability and heuristics,
    this pass will identify some operations to be executed in different agents,
    by marking them with async 'label'. E.g.,
    input:
      %1 = tt.load %0 ...
      %4 = tt.dot %1, %2, %3 ...
    output:
      %1 = tt.load %0 {async_agent = 0} ...
      %4 = tt.dot %1, %2, %3 {async_agent = 1} : ...
  }];

  let constructor = "mlir::createTritonNvidiaGPUWSDecomposingPass()";

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];

  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}
|
||||
|
||||
def TritonGPUWSPipeline : Pass<"triton-nvidia-gpu-ws-pipeline", "mlir::ModuleOp"> {
  let summary = "Warp specialization pipeline";

  let description = [{
  }];

  let constructor = "mlir::createTritonNvidiaGPUWSPipelinePass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];

  let options = [
    Option<"numStages", "num-stages",
           "int32_t", /*default*/"3",
           "number of pipeline stages">,
    Option<"numWarps", "num-warps",
           "int32_t", /*default*/"12",
           "number of warps per block">,
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}

// Fixed: "syncronization" was misspelled in both the summary and the
// description below.
def TritonGPUWSMutex : Pass<"triton-nvidia-gpu-ws-mutex", "mlir::ModuleOp"> {
  let summary = "Warp specialization mutex synchronization";

  let description = [{
    create mutex synchronization for persistent kernel. (as "2 Math WG" persistent kernel in cutlass)
    For example, the agent containing dot and store will be divided into two sub-agent,
    which execute dot and store alternately. i.e.:
    sub-agent-0: dot | store | dot | ... | store
    sub-agent-1:     | dot   | store | ... | dot | store
  }];

  let constructor = "mlir::createTritonNvidiaGPUWSMutexPass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];

  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}
|
||||
|
||||
def TritonGPUWSMaterialization : Pass<"triton-nvidia-gpu-ws-materialization", "mlir::ModuleOp"> {
  let summary = "Warp specialization materialization";

  let description = [{
  }];

  let constructor = "mlir::createTritonNvidiaGPUWSMaterializationPass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];

  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}

def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> {
  let summary = "Insert fences across generic and async proxy";

  let description = [{
  }];

  let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()";

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];

  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}

def TritonGPURewriteTensorPointer : Pass</*cli-arg*/"tritongpu-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
  let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";

  let description = [{
    This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy
    semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute
    the pointer/mask/other for each load/store.
  }];

  let constructor = "mlir::createTritonGPURewriteTensorPointerPass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect"];

  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}

#endif
|
||||
95
include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h
Normal file
95
include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h
Normal file
@@ -0,0 +1,95 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
|
||||
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// 0 is reserved for default sync.
|
||||
// TODO: comprehensive mechanism to globally manage namedbarrier.
|
||||
static int const nameBarrierIdBegin = 1;
|
||||
static int nameBarrierIdEnd = 16;
|
||||
|
||||
/// Helper functions for async agent
|
||||
typedef int AgentId;
|
||||
SmallVector<AgentId> getAgentIds(Operation *op);
|
||||
bool hasAgentId(Operation *op, AgentId agentId);
|
||||
void setAgentIds(Operation *op, ArrayRef<AgentId> agentIds);
|
||||
SmallVector<AgentId> collectAgentIds(Operation *op);
|
||||
void addAgentIds(Operation *op, ArrayRef<int> agents);
|
||||
SmallVector<int> getMutexBarIds(Operation *op);
|
||||
SmallVector<int> getMutexNumThreads(Operation *op);
|
||||
|
||||
class OpBuilderWithAgentIds : public OpBuilder {
|
||||
public:
|
||||
OpBuilderWithAgentIds(MLIRContext *context) : OpBuilder(context) {}
|
||||
|
||||
void setAgentIdsFromArray(ArrayRef<AgentId> newAgentIds) {
|
||||
agentIds = SmallVector<AgentId>(newAgentIds.begin(), newAgentIds.end());
|
||||
}
|
||||
|
||||
void setAgentIdsFromOp(Operation *op) {
|
||||
setAgentIdsFromArray(getAgentIds(op));
|
||||
}
|
||||
|
||||
void setAgentIdsFromValueUsers(Value value) {
|
||||
SetVector<AgentId> agentIdSet;
|
||||
for (Operation *user : value.getUsers())
|
||||
for (AgentId agentId : getAgentIds(user))
|
||||
agentIdSet.insert(agentId);
|
||||
setAgentIdsFromArray(agentIdSet.getArrayRef());
|
||||
}
|
||||
|
||||
template <typename OpTy, typename... Args>
|
||||
OpTy createWithAgentIds(Args &&...args) {
|
||||
OpTy op = create<OpTy>(std::forward<Args>(args)...);
|
||||
if (!agentIds.empty())
|
||||
setAgentIds(op, agentIds);
|
||||
return op;
|
||||
}
|
||||
|
||||
private:
|
||||
SmallVector<AgentId> agentIds;
|
||||
};
|
||||
|
||||
/// Constant agent ids
|
||||
constexpr AgentId kLoadAgentId = 0;
|
||||
constexpr AgentId kDotAgentId = 1;
|
||||
|
||||
bool isWSCandidateLoad(Operation *op);
|
||||
bool isWSSupported(ModuleOp m, int computeCapability);
|
||||
|
||||
LogicalResult getDependentValues(Value val, DenseSet<Value> &depSet,
|
||||
const DenseSet<Value> &stopSet = {});
|
||||
LogicalResult getDependentValues(Operation *op, DenseSet<Value> &depSet,
|
||||
const DenseSet<Value> &stopSet = {});
|
||||
DenseSet<Operation *> getDependentOps(DenseSet<Value> &depSet);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
|
||||
19
include/triton/Target/AMDGCN/AMDGCNTranslation.h
Normal file
19
include/triton/Target/AMDGCN/AMDGCNTranslation.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
|
||||
#define TRITON_TARGET_AMDGCNTRANSLATION_H
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
} // namespace llvm
|
||||
|
||||
namespace triton {
|
||||
|
||||
// Translate LLVM IR to AMDGCN code.
|
||||
std::tuple<std::string, std::string>
|
||||
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
@@ -1,5 +1,6 @@
|
||||
#ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
|
||||
#define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@@ -26,6 +27,7 @@ void addExternalLibs(mlir::ModuleOp &module,
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
bool isROCM);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
@@ -33,6 +35,9 @@ std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
bool isROCM);
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
llvm::StringRef path, bool isROCM);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
107
include/triton/Target/PTX/TmaMetadata.h
Normal file
107
include/triton/Target/PTX/TmaMetadata.h
Normal file
@@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_TARGET_PTX_TMAMETADATA_H
|
||||
#define TRITON_TARGET_PTX_TMAMETADATA_H
|
||||
|
||||
#include "python/triton/third_party/cuda/include/cuda.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace gpu {
|
||||
|
||||
struct TMAInfo {
|
||||
// --------------------------------------------
|
||||
// informations to be filled into CUtensorMaps
|
||||
int tensorDataType;
|
||||
|
||||
uint32_t tensorRank;
|
||||
|
||||
// the argument indices for the runtime to get globalAddresses
|
||||
size_t globalAddressArgIdx;
|
||||
|
||||
// the argument indices for the runtime to get globalDims, -1 stands for this
|
||||
// dim is padded
|
||||
std::vector<int32_t> globalDimsArgIdx;
|
||||
|
||||
// the argument indices for the runtime to get globalStrides, -1 stands for
|
||||
// this dim is padded the runtime need to map the value to internal format
|
||||
std::vector<int32_t> globalStridesArgIdx;
|
||||
|
||||
std::vector<uint32_t> boxDims;
|
||||
|
||||
std::vector<uint32_t> elementStrides;
|
||||
|
||||
int interleave;
|
||||
|
||||
int swizzle;
|
||||
|
||||
int l2Promotion;
|
||||
|
||||
int oobFill;
|
||||
|
||||
// --------------------------------------------
|
||||
// the argument indices for the runtime to send the address of tma_desc to the
|
||||
// binary
|
||||
int TMADescArgIdx;
|
||||
|
||||
template <typename T>
|
||||
void dump_vec(const std::vector<T> &vec, llvm::StringRef info) const {
|
||||
llvm::errs() << info << ": ";
|
||||
for (const T &e : vec)
|
||||
llvm::errs() << e << ",";
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void dump() const {
|
||||
llvm::errs() << "TMA Info: ----------"
|
||||
<< "\n";
|
||||
llvm::errs() << "-- tensorDataType: " << tensorDataType
|
||||
<< ", tensorRank: " << tensorRank << "\n";
|
||||
llvm::errs() << "-- globalAddressArgIdx: " << globalAddressArgIdx << "\n";
|
||||
llvm::errs() << "-- TMADescArgIdx: " << TMADescArgIdx << "\n";
|
||||
dump_vec<int32_t>(globalDimsArgIdx, "-- globalDimsArgIdx");
|
||||
dump_vec<int32_t>(globalStridesArgIdx, "-- globalStridesArgIdx");
|
||||
dump_vec<uint32_t>(boxDims, "-- boxDims");
|
||||
dump_vec<uint32_t>(elementStrides, "-- elementStrides");
|
||||
llvm::errs() << "-- interleave: " << interleave << "\n";
|
||||
llvm::errs() << "-- swizzle: " << swizzle << "\n";
|
||||
llvm::errs() << "-- l2Promotion: " << l2Promotion << "\n";
|
||||
llvm::errs() << "-- oobFill: " << oobFill << "\n";
|
||||
};
|
||||
};
|
||||
|
||||
using TMAMetadataTy = std::vector<TMAInfo>;
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_TARGET_PTX_TMAMETADATA_H
|
||||
@@ -25,9 +25,13 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
namespace triton {
|
||||
|
||||
const std::set<std::string> ENV_VARS = {
|
||||
"ENABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION",
|
||||
"ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP",
|
||||
"AMDGCN_ENABLE_DUMP"};
|
||||
|
||||
namespace tools {
|
||||
|
||||
inline std::string getenv(const char *name) {
|
||||
@@ -39,6 +43,9 @@ inline std::string getenv(const char *name) {
|
||||
}
|
||||
|
||||
inline bool getBoolEnv(const std::string &env) {
|
||||
std::string msg = "Environment variable " + env + " is not recognized";
|
||||
assert(triton::ENV_VARS.find(env.c_str()) != triton::ENV_VARS.end() &&
|
||||
msg.c_str());
|
||||
const char *s = std::getenv(env.c_str());
|
||||
std::string str(s ? s : "");
|
||||
std::transform(str.begin(), str.end(), str.begin(),
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -27,17 +28,21 @@ void SharedMemoryAliasAnalysis::visitOperation(
|
||||
// These ops may allocate a new shared memory buffer.
|
||||
auto result = op->getResult(0);
|
||||
// XXX(Keren): the following ops are always aliasing for now
|
||||
if (isa<triton::gpu::ExtractSliceOp, triton::TransOp>(op)) {
|
||||
if (isa<triton::gpu::ExtractSliceOp, triton::TransOp,
|
||||
triton::nvidia_gpu::ExtractMBarrierOp>(op)) {
|
||||
// extract_slice %src
|
||||
// trans %src
|
||||
aliasInfo = AliasInfo(operands[0]->getValue());
|
||||
pessimistic = false;
|
||||
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
|
||||
op)) {
|
||||
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp,
|
||||
triton::nvidia_gpu::InsertSliceAsyncV2Op>(op)) {
|
||||
// insert_slice_async %src, %dst, %index
|
||||
// insert_slice %src into %dst[%offsets]
|
||||
aliasInfo = AliasInfo(operands[1]->getValue());
|
||||
pessimistic = false;
|
||||
} else if (isa<triton::nvidia_gpu::StoreAsyncOp>(op)) {
|
||||
aliasInfo = AliasInfo(operands[0]->getValue());
|
||||
pessimistic = false;
|
||||
} else if (triton::gpu::isSharedEncoding(result)) {
|
||||
aliasInfo.insert(result);
|
||||
pessimistic = false;
|
||||
|
||||
@@ -16,6 +16,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getShapePerCTATile;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
@@ -57,11 +58,23 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
Attribute dstLayout = dstTy.getEncoding();
|
||||
|
||||
// MmaToDotShortcut doesn't use shared mem
|
||||
if (srcLayout.isa<MmaEncodingAttr>() &&
|
||||
dstLayout.isa<DotOperandEncodingAttr>())
|
||||
if (isMmaToDotShortcut(srcTy, dstTy))
|
||||
return {};
|
||||
if (shouldUseDistSmem(srcLayout, dstLayout)) {
|
||||
// TODO: padding to avoid bank conflicts
|
||||
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
|
||||
}
|
||||
|
||||
// MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem
|
||||
if (auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
if (isMmaToDotShortcut(srcTy, dstTy)) {
|
||||
return {};
|
||||
}
|
||||
} else if (auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (isMmaToMmaShortcut(srcTy, dstTy)) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(srcLayout && dstLayout &&
|
||||
"Unexpected layout in getScratchConfigForCvtLayout()");
|
||||
@@ -73,18 +86,18 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto dstShape = dstTy.getShape();
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcShape);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout, dstShape);
|
||||
auto srcShapePerCTA = getShapePerCTA(srcTy);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstTy);
|
||||
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
|
||||
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
|
||||
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] =
|
||||
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
|
||||
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
|
||||
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
|
||||
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
|
||||
}
|
||||
if (rank == 1)
|
||||
return paddedRepShape;
|
||||
@@ -146,20 +159,45 @@ private:
|
||||
// For example: %a = scf.if -> yield
|
||||
// %a must be allocated elsewhere by other operations.
|
||||
// FIXME(Keren): extract and insert are always alias for now
|
||||
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
|
||||
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op))
|
||||
return;
|
||||
}
|
||||
|
||||
// XXX(Keren): Why this hard-coded alignment?
|
||||
size_t kAlignment = 8;
|
||||
for (Value result : op->getResults()) {
|
||||
if (triton::gpu::isSharedEncoding(result)) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
|
||||
auto bytes = tensorType.getNumElements() *
|
||||
auto shapePerCTA = triton::gpu::getShapePerCTA(tensorType);
|
||||
auto bytes = product<int64_t>(shapePerCTA) *
|
||||
tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
|
||||
|
||||
// XXX(Keren): magic numbers 256 and 1024
|
||||
// benzh@maybe alignment should be passed in.
|
||||
// Software swizzling calculates phase based on offset, while hardware
|
||||
// swizzling do that based on physical address. Thus only by setting the
|
||||
// alignment to 1024 can ensure the correctness.
|
||||
if (bytes > 256)
|
||||
kAlignment = 1024;
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes,
|
||||
kAlignment);
|
||||
}
|
||||
}
|
||||
if (isa<triton::nvidia_gpu::AllocMBarrierOp>(op)) {
|
||||
Value result = op->getResult(0);
|
||||
if (!result.getType().isa<RankedTensorType>())
|
||||
// In case AllocMBarrierOp is allocating scalar mbarriers
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, 8,
|
||||
kAlignment);
|
||||
}
|
||||
}
|
||||
|
||||
template <BufferT::BufferKind T>
|
||||
void maybeAddScratchBuffer(Operation *op, unsigned bytes,
|
||||
unsigned alignment) {
|
||||
if (bytes > 0)
|
||||
allocation->addBuffer<T>(op, bytes, alignment);
|
||||
}
|
||||
|
||||
template <BufferT::BufferKind T>
|
||||
@@ -170,14 +208,17 @@ private:
|
||||
|
||||
/// Initializes temporary shared memory for a given operation.
|
||||
void getScratchValueSize(Operation *op) {
|
||||
const size_t scratchAlignment = 128;
|
||||
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
|
||||
ReduceOpHelper helper(reduceOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
|
||||
ScanLoweringHelper helper(scanOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
|
||||
@@ -201,7 +242,8 @@ private:
|
||||
srcTy.getElementType().isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
|
||||
auto value = op->getOperand(0);
|
||||
// only scalar requires scratch memory
|
||||
@@ -218,7 +260,8 @@ private:
|
||||
elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
}
|
||||
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
|
||||
auto value = op->getOperand(0);
|
||||
@@ -230,13 +273,15 @@ private:
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
|
||||
auto callable = callOp.resolveCallable();
|
||||
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
|
||||
auto *funcAlloc = &(*funcAllocMap)[funcOp];
|
||||
auto bytes = funcAlloc->getSharedMemorySize();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
|
||||
scratchAlignment);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -356,6 +401,12 @@ private:
|
||||
// Analyze liveness of explicit buffers
|
||||
Liveness liveness(operation);
|
||||
auto getValueLivenessRange = [&](Value value) {
|
||||
// Shared memory allocated by mbarrier cannot be reused
|
||||
if (value.getDefiningOp() &&
|
||||
isa<triton::nvidia_gpu::AllocMBarrierOp>(value.getDefiningOp()))
|
||||
return Interval(std::numeric_limits<size_t>::min(),
|
||||
std::numeric_limits<size_t>::max());
|
||||
|
||||
auto liveOperations = liveness.resolveLiveness(value);
|
||||
auto minId = std::numeric_limits<size_t>::max();
|
||||
auto maxId = std::numeric_limits<size_t>::min();
|
||||
@@ -437,17 +488,22 @@ private:
|
||||
auto xRange = bufferRange[buffer];
|
||||
bool res = xRange.intersects(range);
|
||||
for (auto val : tripleMap)
|
||||
res = res && !val.second.intersects(xRange);
|
||||
res = res &&
|
||||
!val.second.intersects(xRange); // only one buffer intersect
|
||||
return res;
|
||||
});
|
||||
if (bufferIt != xBuffers.end()) {
|
||||
auto buffer = *bufferIt;
|
||||
auto xSize = buffer->size;
|
||||
auto xRange = bufferRange.lookup(buffer);
|
||||
bufferStart[buffer] = size;
|
||||
tripleMap.insert(
|
||||
{size + xSize, Interval{std::max(range.start(), xRange.start()),
|
||||
std::min(range.end(), xRange.end())}});
|
||||
// TODO(Keren): A buffer's size shouldn't be determined here, have to
|
||||
// clean it up
|
||||
size_t alignment = buffer->alignment;
|
||||
size_t alignSize = ((size + alignment - 1) / alignment) * alignment;
|
||||
bufferStart[buffer] = alignSize;
|
||||
tripleMap.insert({alignSize + xSize,
|
||||
Interval{std::max(range.start(), xRange.start()),
|
||||
std::min(range.end(), xRange.end())}});
|
||||
// We could either insert (range.start, xRange.start) or (range.start,
|
||||
// xRange.end), both are correct and determine the potential buffer
|
||||
// offset, and the graph coloring algorithm will solve the interference,
|
||||
|
||||
@@ -14,4 +14,5 @@ add_mlir_library(TritonAnalysis
|
||||
MLIRLLVMDialect
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
TritonNvidiaGPUIR
|
||||
)
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
#include "triton/Analysis/Alias.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
#include "../lib/Conversion/TritonGPUToLLVM/Utility.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include <deque>
|
||||
@@ -103,7 +107,11 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
return;
|
||||
}
|
||||
|
||||
if (isa<gpu::BarrierOp>(op)) {
|
||||
// TODO(Keren): Don't expose LLVM Dialect ops here
|
||||
if (isa<gpu::BarrierOp>(op) ||
|
||||
(isa<LLVM::InlineAsmOp>(op) &&
|
||||
(dyn_cast<LLVM::InlineAsmOp>(op).getAsmString().find("bar.sync") !=
|
||||
std::string::npos))) {
|
||||
// If the current op is a barrier, we sync previous reads and writes
|
||||
blockInfo->sync();
|
||||
return;
|
||||
@@ -169,12 +177,23 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
if (blockInfo->isIntersected(curBlockInfo)) {
|
||||
OpBuilder::InsertionGuard g(*builder);
|
||||
builder->setInsertionPoint(op);
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
blockInfo->sync();
|
||||
// TODO(Keren): Don't expose LLVM Dialect ops here
|
||||
// TODO[shuhaoj]: Change hard code style of numThreads. Hide async_agent
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
barSync(*builder, op, barId, 128);
|
||||
} else {
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
blockInfo->sync();
|
||||
}
|
||||
}
|
||||
// Update the region info, even if barrier is inserted, we have to maintain
|
||||
// the current op's read/write buffers.
|
||||
blockInfo->join(curBlockInfo);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include <deque>
|
||||
|
||||
@@ -37,6 +41,51 @@ bool ReduceOpHelper::isFastReduction() {
|
||||
getParentOrder(getSrcLayout())[0];
|
||||
}
|
||||
|
||||
// Cases where distributed shared memory is not required in ConvertLayout:
|
||||
// (1) numCTAs == 1
|
||||
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
|
||||
// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented
|
||||
// in the future
|
||||
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
|
||||
unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout);
|
||||
assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) &&
|
||||
"Invalid layout conversion: the numbers of CTAs of src and dst "
|
||||
"layouts are different");
|
||||
|
||||
// Case (1): Never use dsmem when numCTAs == 1
|
||||
if (numCTAs == 1)
|
||||
return false;
|
||||
|
||||
// Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not
|
||||
// implemented yet
|
||||
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent());
|
||||
if (CTAsPerCGA[dim] != 1)
|
||||
assert(0 && "Layout conversion to be implemented");
|
||||
}
|
||||
|
||||
// Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported
|
||||
if (auto sliceLayout = dstLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent());
|
||||
if (CTAsPerCGA[dim] != 1)
|
||||
return true;
|
||||
}
|
||||
|
||||
// The above two branches make sure that it is legal to call getCTALayout of
|
||||
// srcLayout and dstLayout
|
||||
|
||||
// Case (2): Do not use dsmem when srcCTALayout == dstCTALayout
|
||||
auto srcCTALayout = triton::gpu::getCTALayout(srcLayout);
|
||||
auto dstCTALayout = triton::gpu::getCTALayout(dstLayout);
|
||||
if (srcCTALayout == dstCTALayout)
|
||||
return false;
|
||||
|
||||
// Dsmem access is required when srcCTALayout != dstCTALayout
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getInterWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned sizeIntraWarps = getIntraWarpSize();
|
||||
@@ -136,7 +185,7 @@ bool ReduceOpHelper::isSupportedLayout() {
|
||||
return true;
|
||||
}
|
||||
if (auto mmaLayout = srcLayout.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -282,6 +331,8 @@ bool maybeSharedAllocationOp(Operation *op) {
|
||||
return dialect &&
|
||||
(dialect->getTypeID() ==
|
||||
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
|
||||
dialect->getTypeID() ==
|
||||
mlir::TypeID::get<triton::nvidia_gpu::TritonNvidiaGPUDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<arith::ArithDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
|
||||
@@ -290,6 +341,8 @@ bool maybeSharedAllocationOp(Operation *op) {
|
||||
bool maybeAliasOp(Operation *op) {
|
||||
return isa<triton::gpu::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
|
||||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
|
||||
isa<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op) ||
|
||||
isa<triton::nvidia_gpu::StoreAsyncOp>(op) ||
|
||||
isa<tensor::InsertSliceOp>(op);
|
||||
}
|
||||
|
||||
@@ -299,6 +352,21 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
if (version == 3) {
|
||||
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
|
||||
return false;
|
||||
auto retType = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto retShapePerCTA = triton::gpu::getShapePerCTA(retType);
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
if (!(numWarps % 4 == 0 && retShapePerCTA[0] % 64 == 0 &&
|
||||
retShapePerCTA[1] % 8 == 0 &&
|
||||
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
|
||||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
|
||||
aElemTy.isF32()))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (aElemTy.isF32() && bElemTy.isF32()) {
|
||||
return op.getAllowTF32() && version >= 2;
|
||||
}
|
||||
@@ -306,24 +374,21 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
}
|
||||
|
||||
bool supportMMA(Value value, int version) {
|
||||
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
|
||||
// Tell whether a DotOp support MMA by the operand type(either $a or $b).
|
||||
// We cannot get both the operand types(in TypeConverter), here we assume the
|
||||
// types of both the operands are identical here.
|
||||
assert((version == 1 || version == 2) &&
|
||||
assert((version == 1 || version == 2 || version == 3) &&
|
||||
"Unexpected MMA layout version found");
|
||||
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
|
||||
return elemTy.isF16() || elemTy.isBF16() ||
|
||||
// FP8 is not natively supported on all mma versions but it can always be
|
||||
// promoted to fp16 therefore we can always support it.
|
||||
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
|
||||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
|
||||
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
|
||||
(elemTy.isF32() && version >= 2) ||
|
||||
(elemTy.isInteger(8) && version >= 2);
|
||||
}
|
||||
|
||||
Type getElementType(Value value) {
|
||||
auto type = value.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return tensorType.getElementType();
|
||||
return type;
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
@@ -338,6 +403,17 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy);
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 &&
|
||||
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
|
||||
srcElemsPerThread == dstElemsPerThread;
|
||||
}
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
|
||||
@@ -557,4 +633,81 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
|
||||
return solver;
|
||||
}
|
||||
|
||||
// Traces the defining op of `v` back to the triton::MakeTensorPtrOp that
// ultimately produced it, looking through tt.advance and region-carrying ops
// (scf.for / scf.if, ...). Aborts via llvm_unreachable if no producer can be
// identified. `v` must be a result of `op`.
static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
  // Direct hit: `v` is produced by a make_tensor_ptr.
  if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
    return makeTensorPtrOp;
  }

  // tt.advance only offsets an existing tensor pointer; keep walking up
  // through its base-pointer operand.
  if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
    return getMakeTensorPtrOp(advanceOp.getPtr());
  }

  // Region-carrying ops: the result at index `idx` is fed by the operand at
  // the same index of the region's yield, so continue the trace from there.
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    auto idx = v.cast<OpResult>().getResultNumber();
    llvm::SmallVector<scf::YieldOp> yieldOps;
    op->walk([&](Operation *nestedOp) {
      if (auto yieldOp = dyn_cast<scf::YieldOp>(nestedOp))
        yieldOps.push_back(yieldOp);
    });

    // Guard the unchecked [0] access: a region branch op without any yield
    // would otherwise read out of bounds.
    assert(!yieldOps.empty() && "expected at least one scf.yield in regions");
    // benzh@ if multi yields, all yields operand should come from same arg.
    Value newValue = yieldOps[0].getOperands()[idx];
    return getMakeTensorPtrOp(newValue);
  }

  llvm_unreachable("Unable to getMakeTensorPtr()");
}
|
||||
|
||||
// Finds the triton::MakeTensorPtrOp that (transitively) produced `v`.
// Handles values that are op results (delegated to getMakeTensorPtrOpImpl)
// as well as block arguments of scf.for, tt.func (via cf.br/cf.cond_br
// predecessors), and other region ops. Aborts if no producer is found.
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) {
  // Map every block to the unstructured-CF ops that branch into it, paired
  // with which successor edge was taken: -1 = cf.br, 1 = true edge of
  // cf.cond_br, 0 = false edge.
  using BranchOps = llvm::SetVector<std::pair<Operation *, int>>;
  llvm::DenseMap<Block *, BranchOps> blockToCFOps;
  auto moduleOp =
      v.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();

  moduleOp.walk([&](Operation *op) {
    if (auto br = dyn_cast<cf::BranchOp>(op)) {
      Block *block = br.getDest();
      blockToCFOps[block].insert({op, -1});
    }
    if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
      Block *blockT = condBr.getTrueDest();
      Block *blockF = condBr.getFalseDest();
      blockToCFOps[blockT].insert({condBr, 1});
      blockToCFOps[blockF].insert({condBr, 0});
    }
  });

  if (Operation *definingOp = v.getDefiningOp()) {
    // `v` is an op result: trace through the defining op.
    return getMakeTensorPtrOpImpl(definingOp, v);
  } else if (BlockArgument arg = v.cast<BlockArgument>()) {
    unsigned argNum = arg.getArgNumber();
    Operation *argOwner = arg.getOwner()->getParentOp();

    if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
      // Map the loop-body argument back to the corresponding init operand.
      // `- 1` skips the induction variable, which has no matching operand.
      return getMakeTensorPtrOp(
          forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));
    } else if (auto funcOp = dyn_cast<mlir::triton::FuncOp>(argOwner)) {
      // Block argument inside a function body: follow the branch operand of
      // a predecessor. NOTE(review): only the first recorded predecessor is
      // consulted — assumes all predecessors pass the same tensor pointer;
      // also assumes the block has at least one cf predecessor.
      Block *block = arg.getOwner();
      Operation *op;
      int tOrF;
      std::tie(op, tOrF) = blockToCFOps[block][0];
      if (auto br = dyn_cast<cf::BranchOp>(op)) {
        return getMakeTensorPtrOp(br.getDestOperands()[argNum]);
      }
      if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
        // Pick the operand list matching the edge that targets this block.
        if (tOrF) {
          return getMakeTensorPtrOp(condBr.getTrueDestOperands()[argNum]);
        } else {
          return getMakeTensorPtrOp(condBr.getFalseDestOperands()[argNum]);
        }
      }
    } else {
      // Other region op (e.g. scf.while): forward to the matching operand.
      return getMakeTensorPtrOp(argOwner->getOperand(argNum));
    }
  }

  llvm_unreachable("Unable to getMakeTensorPtr()");
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -3,3 +3,4 @@ add_subdirectory(Analysis)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Target)
|
||||
add_subdirectory(Hopper)
|
||||
|
||||
217
lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp
Normal file
217
lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp
Normal file
@@ -0,0 +1,217 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "BarrierOpToLLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// -- MBarrier related Ops lowering, to be moved to a seperate file ---------
|
||||
// --------------------------------------------------------------------------
|
||||
// Lowers triton_nvidia_gpu.alloc_mbarrier: carves the mbarrier storage out of
// shared memory and initializes each barrier (thread 0 only), then replaces
// the op with either a shared-memory-object struct (tensor result, i.e. an
// array of mbarriers) or the bare shared-memory pointer (scalar result).
struct AllocMBarrierOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                       triton::nvidia_gpu::AllocMBarrierOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::AllocMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::AllocMBarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
    auto resultTy = op.getType();
    // Tensor result => an array of mbarriers; scalar result => exactly one.
    auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
    Type elemPtrTy;
    if (resultTensorTy) {
      auto llvmElemTy =
          getTypeConverter()->convertType(resultTensorTy.getElementType());
      // Address space 3 = shared memory.
      elemPtrTy = ptr_ty(llvmElemTy, 3);
    } else {
      elemPtrTy = getTypeConverter()->convertType(resultTy);
    }
    smemBase = bitcast(smemBase, elemPtrTy);
    auto threadId = getThreadId(rewriter, loc);
    // Only thread 0 performs the mbarrier initialization.
    auto pred = icmp_eq(threadId, i32_val(0));
    int numMBarriers = 1;
    if (resultTensorTy) {
      assert(resultTensorTy.getRank() == 1 &&
             "unexpected rank for AllocMBarrierOp");
      numMBarriers = resultTensorTy.getShape()[0];
    }
    // Initialize each mbarrier at its offset from the shared-memory base.
    for (int i = 0; i < numMBarriers; ++i) {
      Value smem = smemBase;
      if (i > 0) {
        smem = gep(elemPtrTy, smem, i32_val(i));
      }
      rewriter.create<triton::nvgpu::MBarrierInitOp>(loc, smem, pred,
                                                     op.getCount());
    }
    if (resultTensorTy) {
      // Wrap base + shape into the usual shared-memory-object LLVM struct.
      auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(),
                                        {0}, loc, rewriter);
      auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
      rewriter.replaceOp(op, retVal);
    } else {
      rewriter.replaceOp(op, smemBase);
    }
    return success();
  }
};
|
||||
|
||||
// Lowers triton_nvidia_gpu.mbarrier_arrive to the NVGPU arrive op, choosing
// the arrive flavor from the op's attributes: cp_async tracking, remote-CTA
// arrive, expect-tx (transaction count), or a plain arrive. The priority of
// the checks below is significant (trackAsyncOp wins over remote/expect_tx).
struct MBarrierArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                        triton::nvidia_gpu::MBarrierArriveOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::MBarrierArriveOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto mbarrier = adaptor.getMbarrier();
    bool trackAsyncOp = op.getTrackAsyncOp();
    triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal;
    uint32_t txCount = op.getTxCount();
    auto remoteCtaId = adaptor.getRemoteCtaId();
    if (trackAsyncOp) {
      type = triton::nvgpu::MBarriveType::cp_async;
    } else if (remoteCtaId) {
      // Remote arrive and a transaction count are mutually exclusive today.
      assert(txCount == 0 &&
             "remote arrive of transaction mbarrier is not implemented yet");
      type = triton::nvgpu::MBarriveType::remote;
    } else if (txCount > 0) {
      type = triton::nvgpu::MBarriveType::expect_tx;
    }
    Value pred = adaptor.getPred();
    if (pred == nullptr) {
      // No predicate supplied: default to "always true" (i1 constant 1).
      pred = int_val(/*width*/ 1, 1);
    }
    rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierArriveOp>(
        op, mbarrier, pred, remoteCtaId, type, txCount);
    return success();
  }
};
|
||||
|
||||
// Lowers triton_nvidia_gpu.mbarrier_wait 1:1 to its NVGPU counterpart,
// forwarding the already-converted mbarrier pointer and phase operands.
struct MBarrierWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                      triton::nvidia_gpu::MBarrierWaitOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::MBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::MBarrierWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // NOTE: the original declared `Location loc = op->getLoc();` but never
    // used it; removed to avoid an unused-variable warning.
    rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierWaitOp>(
        op, adaptor.getMbarrier(), adaptor.getPhase());
    return success();
  }
};
|
||||
|
||||
// Lowers triton_nvidia_gpu.extract_mbarrier: computes the shared-memory
// address of one mbarrier inside an mbarrier tensor by indexing off the
// tensor's base pointer (field 0 of the shared-memory-object struct).
struct ExtractMBarrierOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<
          triton::nvidia_gpu::ExtractMBarrierOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::ExtractMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::ExtractMBarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto elemTy =
        op.getTensor().getType().cast<RankedTensorType>().getElementType();
    auto tensorStruct = adaptor.getTensor();
    auto index = adaptor.getIndex();
    // Address space 3 = shared memory.
    auto ptrTy =
        LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3);
    // Field 0 of the shared-memory-object struct is the base pointer.
    auto basePtr =
        extract_val(ptrTy, tensorStruct, rewriter.getDenseI64ArrayAttr(0));
    Value result = gep(ptrTy, basePtr, index);
    rewriter.replaceOp(op, result);
    return success();
  }
};
|
||||
|
||||
struct NamedBarrierArriveOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierArriveOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierArriveOp>::
|
||||
ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierArriveOp>(
|
||||
op, adaptor.getBar(), adaptor.getNumThreads());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct NamedBarrierWaitOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierWaitOp>(
|
||||
op, adaptor.getBar(), adaptor.getNumThreads());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FenceAsyncSharedOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::FenceAsyncSharedOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::FenceAsyncSharedOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::FenceAsyncSharedOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::FenceAsyncSharedOp>(
|
||||
op, adaptor.getBCluster());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Registers all barrier-related TritonGPU -> LLVM conversion patterns.
// numWarps and axisInfoAnalysis are unused here but kept so every
// populate* entry point shares the same signature.
void populateBarrierOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation, PatternBenefit benefit) {
  // Every pattern takes the same constructor arguments, so register them all
  // with a single variadic add<> call (same registration order as before).
  patterns.add<AllocMBarrierOpConversion, MBarrierArriveOpConversion,
               MBarrierWaitOpConversion, ExtractMBarrierOpConversion,
               NamedBarrierArriveOpConversion, NamedBarrierWaitOpConversion,
               FenceAsyncSharedOpConversion>(typeConverter, allocation,
                                             benefit);
}
|
||||
37
lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.h
Normal file
37
lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.h
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateBarrierOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
@@ -1,4 +1,16 @@
|
||||
add_mlir_conversion_library(TritonGPUToLLVM
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
|
||||
ConvertLayoutOpToLLVM.cpp
|
||||
DotOpToLLVM/FMA.cpp
|
||||
DotOpToLLVM/MMAv1.cpp
|
||||
DotOpToLLVM/MMAv2.cpp
|
||||
DotOpToLLVM/WGMMA.cpp
|
||||
DotOpToLLVM.cpp
|
||||
ElementwiseOpToLLVM.cpp
|
||||
LoadStoreOpToLLVM.cpp
|
||||
BarrierOpToLLVM.cpp
|
||||
TritonGPUToLLVM.cpp
|
||||
GCNAsmFormat.cpp
|
||||
PTXAsmFormat.cpp
|
||||
@@ -15,12 +27,16 @@ add_mlir_conversion_library(TritonGPUToLLVM
|
||||
LoadStoreOpToLLVM.cpp
|
||||
TritonGPUToLLVM.cpp
|
||||
TritonGPUToLLVMPass.cpp
|
||||
GCNAsmFormat.cpp
|
||||
PTXAsmFormat.cpp
|
||||
ReduceOpToLLVM.cpp
|
||||
ScanOpToLLVM.cpp
|
||||
Utility.cpp
|
||||
TypeConverter.cpp
|
||||
Utility.cpp
|
||||
ViewOpToLLVM.cpp
|
||||
TensorPtrOpsToLLVM.cpp
|
||||
ClusterOpsToLLVM.cpp
|
||||
RegReallocOpToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
|
||||
@@ -43,4 +59,6 @@ add_mlir_conversion_library(TritonGPUToLLVM
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
TritonGPUTransforms
|
||||
TritonNvidiaGPUTransforms
|
||||
NVGPUIR
|
||||
)
|
||||
|
||||
62
lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.cpp
Normal file
62
lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.cpp
Normal file
@@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "ClusterOpsToLLVM.h"
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
// Lowers triton_nvidia_gpu.cluster_arrive 1:1 to the NVGPU op, forwarding
// the `relaxed` flag (relaxed vs. release ordering of the arrive).
struct ClusterArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                       triton::nvidia_gpu::ClusterArriveOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::ClusterArriveOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterArriveOp>(
        op, op.getRelaxed());
    return success();
  }
};
||||
|
||||
// Lowers triton_nvidia_gpu.cluster_wait 1:1 to the NVGPU op (no operands).
struct ClusterWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                     triton::nvidia_gpu::ClusterWaitOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::ClusterWaitOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterWaitOp>(op);
    return success();
  }
};
|
||||
|
||||
// Registers the CTA-cluster arrive/wait conversion patterns. numWarps,
// axisInfoAnalysis and allocation are unused here but kept so every
// populate* entry point shares the same signature.
void populateClusterOpsToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation, PatternBenefit benefit) {
  // Both patterns take identical constructor arguments, so register them in
  // one call; also dropped the original's redundant trailing `return;`.
  patterns.add<ClusterArriveOpConversion, ClusterWaitOpConversion>(
      typeConverter, benefit);
}
|
||||
37
lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.h
Normal file
37
lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.h
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateClusterOpsToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
@@ -1,14 +1,18 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::LLVM::linearize;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getShapePerCTATile;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
@@ -72,6 +76,13 @@ public:
|
||||
dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
return lowerSharedToDotOperand(op, adaptor, rewriter);
|
||||
}
|
||||
// forwarding on mma->mma shortcut, lower distributed->distributed otherwise
|
||||
if (srcLayout.isa<MmaEncodingAttr>() && dstLayout.isa<MmaEncodingAttr>()) {
|
||||
if (isMmaToMmaShortcut(srcTy, dstTy)) {
|
||||
rewriter.replaceOp(op, op.getSrc());
|
||||
return success();
|
||||
}
|
||||
}
|
||||
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
|
||||
return lowerDistributedToDistributed(op, adaptor, rewriter);
|
||||
}
|
||||
@@ -89,23 +100,25 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
unsigned elemId, RankedTensorType type,
|
||||
ArrayRef<unsigned> multiDimCTAInRepId,
|
||||
ArrayRef<unsigned> shapePerCTA) const {
|
||||
SmallVector<Value>
|
||||
getMultiDimOffset(Attribute layout, Location loc,
|
||||
ConversionPatternRewriter &rewriter, unsigned elemId,
|
||||
RankedTensorType type,
|
||||
ArrayRef<unsigned> multiDimCTAInRepId,
|
||||
ArrayRef<unsigned> shapePerCTATile) const {
|
||||
auto shape = type.getShape();
|
||||
unsigned rank = shape.size();
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
auto multiDimOffsetFirstElem =
|
||||
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
|
||||
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type, false);
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
|
||||
elemId, getSizePerThread(layout), getOrder(layout));
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
|
||||
i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
|
||||
multiDimElemId[d]));
|
||||
multiDimOffset[d] =
|
||||
add(multiDimOffsetFirstElem[d],
|
||||
i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] +
|
||||
multiDimElemId[d]));
|
||||
}
|
||||
return multiDimOffset;
|
||||
}
|
||||
@@ -127,7 +140,7 @@ private:
|
||||
auto multiDimOffsetParent = getMultiDimOffset(
|
||||
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
|
||||
sliceLayout.paddedShape(multiDimCTAInRepId),
|
||||
sliceLayout.paddedShape(shapePerCTA));
|
||||
sliceLayout.paddedShape(shapePerCTATile));
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
for (unsigned d = 0; d < rank + 1; ++d) {
|
||||
if (d == dim)
|
||||
@@ -138,6 +151,8 @@ private:
|
||||
return multiDimOffset;
|
||||
}
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
|
||||
auto instrShape = mmaLayout.getInstrShape();
|
||||
SmallVector<Value> mmaColIdx(4);
|
||||
SmallVector<Value> mmaRowIdx(2);
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
@@ -145,27 +160,35 @@ private:
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
// TODO: fix the bug in MMAEncodingAttr document
|
||||
SmallVector<Value> multiDimWarpId(2);
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
if (mmaLayout.isHopper()) {
|
||||
multiDimWarpId[0] = urem(warpId, i32_val(warpsPerCTA[0]));
|
||||
multiDimWarpId[1] = udiv(warpId, i32_val(warpsPerCTA[0]));
|
||||
} else {
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
}
|
||||
Value _1 = i32_val(1);
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
Value _8 = i32_val(8);
|
||||
Value _16 = i32_val(16);
|
||||
if (mmaLayout.isAmpere()) {
|
||||
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
|
||||
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
|
||||
if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
|
||||
multiDimWarpId[0] =
|
||||
urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0]));
|
||||
multiDimWarpId[1] =
|
||||
urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1]));
|
||||
|
||||
Value mmaGrpId = udiv(laneId, _4);
|
||||
Value mmaGrpIdP8 = add(mmaGrpId, _8);
|
||||
Value mmaThreadIdInGrp = urem(laneId, _4);
|
||||
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
|
||||
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
|
||||
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
|
||||
Value rowWarpOffset = mul(multiDimWarpId[0], i32_val(instrShape[0]));
|
||||
mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
|
||||
mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
|
||||
Value colWarpOffset = mul(multiDimWarpId[1], _8);
|
||||
Value colWarpOffset = mul(multiDimWarpId[1], i32_val(instrShape[1]));
|
||||
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
|
||||
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
@@ -176,13 +199,27 @@ private:
|
||||
|
||||
assert(rank == 2);
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
if (mmaLayout.isAmpere()) {
|
||||
if (mmaLayout.isHopper()) {
|
||||
unsigned elemIdRem4 = elemId % 4;
|
||||
unsigned nGrpId = elemId / 4;
|
||||
multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
|
||||
multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
|
||||
multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId));
|
||||
multiDimOffset[0] =
|
||||
add(multiDimOffset[0],
|
||||
i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0]));
|
||||
multiDimOffset[1] =
|
||||
add(multiDimOffset[1],
|
||||
i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1]));
|
||||
} else if (mmaLayout.isAmpere()) {
|
||||
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
|
||||
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
|
||||
multiDimOffset[0] = add(
|
||||
multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
|
||||
multiDimOffset[1] = add(
|
||||
multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
|
||||
multiDimOffset[0] =
|
||||
add(multiDimOffset[0],
|
||||
i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0]));
|
||||
multiDimOffset[1] =
|
||||
add(multiDimOffset[1],
|
||||
i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1]));
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, _] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
@@ -211,11 +248,12 @@ private:
|
||||
auto rank = type.getRank();
|
||||
auto sizePerThread = getSizePerThread(layout);
|
||||
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
||||
SmallVector<unsigned> numCTAs(rank);
|
||||
SmallVector<unsigned> numCTATiles(rank);
|
||||
auto shapePerCTATile = getShapePerCTATile(layout);
|
||||
auto shapePerCTA = getShapePerCTA(layout, type.getShape());
|
||||
auto order = getOrder(layout);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
|
||||
numCTATiles[d] = ceil<unsigned>(shapePerCTA[d], shapePerCTATile[d]);
|
||||
}
|
||||
auto elemTy = type.getElementType();
|
||||
bool isInt1 = elemTy.isInteger(1);
|
||||
@@ -238,17 +276,16 @@ private:
|
||||
}
|
||||
|
||||
auto linearCTAId =
|
||||
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
|
||||
getLinearIndex<unsigned>(multiDimCTAId, numCTATiles, order);
|
||||
// TODO: This is actually redundant index calculation, we should
|
||||
// consider of caching the index calculation result in case
|
||||
// of performance issue observed.
|
||||
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
|
||||
SmallVector<Value> multiDimOffset =
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type,
|
||||
multiDimCTAInRepId, shapePerCTA);
|
||||
multiDimCTAInRepId, shapePerCTATile);
|
||||
Value offset =
|
||||
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
|
||||
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
Value ptr = gep(elemPtrTy, smemBase, offset);
|
||||
auto vecTy = vec_ty(llvmElemTy, vec);
|
||||
@@ -305,7 +342,8 @@ private:
|
||||
|
||||
SmallVector<unsigned> numCTAs(rank, 1);
|
||||
SmallVector<unsigned> numCTAsEachRep(rank, 1);
|
||||
SmallVector<unsigned> shapePerCTA = getShapePerCTA(layout, shape);
|
||||
SmallVector<unsigned> shapePerCTATile = getShapePerCTATile(layout, shape);
|
||||
SmallVector<int64_t> shapePerCTA = getShapePerCTA(layout, shape);
|
||||
auto elemTy = type.getElementType();
|
||||
|
||||
int ctaId = 0;
|
||||
@@ -335,7 +373,7 @@ private:
|
||||
// duplicate in Volta.
|
||||
SmallVector<Value> multiDimOffset =
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type,
|
||||
multiDimCTAInRepId, shapePerCTA);
|
||||
multiDimCTAInRepId, shapePerCTATile);
|
||||
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
|
||||
}
|
||||
|
||||
@@ -343,7 +381,7 @@ private:
|
||||
// do transpose
|
||||
auto aEncoding =
|
||||
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
|
||||
int numM = aEncoding.getMMAv1NumOuter(shape);
|
||||
int numM = aEncoding.getMMAv1NumOuter(shapePerCTA);
|
||||
int numN = accumSizePerThread / numM;
|
||||
|
||||
for (int r = 0; r < numM; r++) {
|
||||
@@ -382,6 +420,91 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value src = op.getSrc();
|
||||
Value dst = op.getResult();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto srcShapePerCTA = getShapePerCTA(srcTy);
|
||||
auto srcCTAsPerCGA = triton::gpu::getCTAsPerCGA(srcLayout);
|
||||
auto srcCTAOrder = triton::gpu::getCTAOrder(srcLayout);
|
||||
unsigned rank = srcShapePerCTA.size();
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto smemShape = convertType<unsigned, int64_t>(srcShapePerCTA);
|
||||
|
||||
// Store to local shared memory
|
||||
{
|
||||
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
auto inIndices =
|
||||
emitIndices(loc, rewriter, srcLayout, srcTy, /*withCTAOffset*/ false);
|
||||
|
||||
assert(inIndices.size() == inVals.size() &&
|
||||
"Unexpected number of indices emitted");
|
||||
|
||||
for (unsigned i = 0; i < inIndices.size(); ++i) {
|
||||
Value offset = linearize(rewriter, loc, inIndices[i], smemShape);
|
||||
Value ptr = gep(elemPtrTy, smemBase, offset);
|
||||
store(inVals[i], ptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster barrier
|
||||
rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
|
||||
rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
|
||||
|
||||
// Load from remote shared memory
|
||||
{
|
||||
SmallVector<Value> srcShapePerCTACache;
|
||||
for (unsigned i = 0; i < rank; ++i)
|
||||
srcShapePerCTACache.push_back(i32_val(srcShapePerCTA[i]));
|
||||
|
||||
SmallVector<Value> outVals;
|
||||
auto outIndices =
|
||||
emitIndices(loc, rewriter, dstLayout, dstTy, /*withCTAOffset*/ true);
|
||||
|
||||
for (unsigned i = 0; i < outIndices.size(); ++i) {
|
||||
auto coord = outIndices[i];
|
||||
assert(coord.size() == rank && "Unexpected rank of index emitted");
|
||||
|
||||
SmallVector<Value> multiDimCTAId, localCoord;
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
multiDimCTAId.push_back(udiv(coord[d], srcShapePerCTACache[d]));
|
||||
localCoord.push_back(urem(coord[d], srcShapePerCTACache[d]));
|
||||
}
|
||||
|
||||
Value remoteCTAId =
|
||||
linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder);
|
||||
Value localOffset = linearize(rewriter, loc, localCoord, smemShape);
|
||||
|
||||
Value ptr = gep(elemPtrTy, smemBase, localOffset);
|
||||
outVals.push_back(load_dsmem(ptr, remoteCTAId));
|
||||
}
|
||||
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
}
|
||||
|
||||
// Cluster barrier
|
||||
rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
|
||||
rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// blocked/mma -> blocked/mma.
|
||||
// Data padding in shared memory to avoid bank conflict.
|
||||
LogicalResult
|
||||
@@ -395,6 +518,10 @@ private:
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
Attribute dstLayout = dstTy.getEncoding();
|
||||
|
||||
if (shouldUseDistSmem(srcLayout, dstLayout))
|
||||
return lowerDistToDistWithDistSmem(op, adaptor, rewriter);
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
@@ -406,8 +533,9 @@ private:
|
||||
SmallVector<unsigned> outNumCTAsEachRep(rank);
|
||||
SmallVector<unsigned> inNumCTAs(rank);
|
||||
SmallVector<unsigned> outNumCTAs(rank);
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape());
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout, shape);
|
||||
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
|
||||
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
|
||||
auto shapePerCTA = getShapePerCTA(srcLayout, shape);
|
||||
|
||||
// For Volta, all the coords for a CTA are calculated.
|
||||
bool isSrcMmaV1{}, isDstMmaV1{};
|
||||
@@ -427,15 +555,17 @@ private:
|
||||
}
|
||||
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
|
||||
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
|
||||
unsigned inPerCTA =
|
||||
std::min<unsigned>(shapePerCTA[d], srcShapePerCTATile[d]);
|
||||
unsigned outPerCTA =
|
||||
std::min<unsigned>(shapePerCTA[d], dstShapePerCTATile[d]);
|
||||
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
|
||||
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
|
||||
numReplicates[d] = ceil<unsigned>(shapePerCTA[d], maxPerCTA);
|
||||
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
|
||||
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
|
||||
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
|
||||
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
|
||||
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
|
||||
inNumCTAs[d] = ceil<unsigned>(shapePerCTA[d], inPerCTA);
|
||||
outNumCTAs[d] = ceil<unsigned>(shapePerCTA[d], outPerCTA);
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
auto accumNumReplicates = product<unsigned>(numReplicates);
|
||||
@@ -456,8 +586,26 @@ private:
|
||||
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
|
||||
auto multiDimRepId =
|
||||
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
|
||||
if (repId != 0)
|
||||
barrier();
|
||||
if (repId != 0) {
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId =
|
||||
op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
auto bar = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
|
||||
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32_ty, rewriter.getI32IntegerAttr(128));
|
||||
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
|
||||
kNumThreads);
|
||||
} else {
|
||||
barrier();
|
||||
}
|
||||
}
|
||||
if (srcLayout.isa<BlockedEncodingAttr>() ||
|
||||
srcLayout.isa<SliceEncodingAttr>() ||
|
||||
srcLayout.isa<MmaEncodingAttr>()) {
|
||||
@@ -474,7 +622,23 @@ private:
|
||||
return failure();
|
||||
}
|
||||
|
||||
barrier();
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
auto bar = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
|
||||
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32_ty, rewriter.getI32IntegerAttr(128));
|
||||
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
|
||||
kNumThreads);
|
||||
} else {
|
||||
barrier();
|
||||
}
|
||||
if (dstLayout.isa<BlockedEncodingAttr>() ||
|
||||
dstLayout.isa<SliceEncodingAttr>() ||
|
||||
dstLayout.isa<MmaEncodingAttr>()) {
|
||||
@@ -545,7 +709,7 @@ private:
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy);
|
||||
assert(srcShape.size() == 2 &&
|
||||
"Unexpected rank of ConvertLayout(blocked->shared)");
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
@@ -557,13 +721,93 @@ private:
|
||||
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
|
||||
auto dstStrides =
|
||||
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
|
||||
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
|
||||
smemBase, elemTy, loc, rewriter);
|
||||
int32_t elemSize = elemTy.getIntOrFloatBitWidth();
|
||||
auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
||||
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
if (mmaLayout && mmaLayout.isHopper() && elemSize == 16 &&
|
||||
inOrd == outOrd && numElems >= 16) {
|
||||
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
|
||||
auto srcShapePerCTA = getShapePerCTA(mmaLayout, srcShape);
|
||||
auto instrShape = mmaLayout.getInstrShape();
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
uint32_t repM =
|
||||
ceil<unsigned>(srcShapePerCTA[0], instrShape[0] * warpsPerCTA[0]);
|
||||
uint32_t numElemsPerRep = numElems / repM;
|
||||
// rowStride in bytes
|
||||
uint32_t rowStrideInBytes = dstShapePerCTA[outOrd[0]] * 2;
|
||||
uint32_t swizzlingByteWidth = rowStrideInBytes;
|
||||
if (swizzlingByteWidth > 128)
|
||||
swizzlingByteWidth = 128;
|
||||
|
||||
unsigned numElemsPerSwizzlingRow = swizzlingByteWidth * 8 / elemSize;
|
||||
unsigned leadingDimOffset =
|
||||
numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]];
|
||||
|
||||
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
|
||||
typeConverter->convertType(rewriter.getI8Type()), 3);
|
||||
|
||||
uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0];
|
||||
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpId = udiv(threadId, i32_val(32));
|
||||
Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])),
|
||||
i32_val(srcShape[0] / instrShape[0]));
|
||||
|
||||
for (int rep = 0; rep < repM; ++rep) {
|
||||
Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])),
|
||||
i32_val(rep * rowsPerRep));
|
||||
uint32_t elemIdxOffset = rep * numElemsPerRep;
|
||||
|
||||
for (unsigned idx = 0; idx < numElemsPerRep; idx += 8) {
|
||||
uint32_t elemIdx = elemIdxOffset + idx;
|
||||
|
||||
Value offset = rewriter.create<triton::nvgpu::OffsetOfStmatrixV4Op>(
|
||||
loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset,
|
||||
numElemsPerSwizzlingRow, true);
|
||||
|
||||
Value addr = gep(elemPtrTy, smemBase, offset);
|
||||
Value data0 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 1], inVals[elemIdx + 0]);
|
||||
Value data1 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 3], inVals[elemIdx + 2]);
|
||||
Value data2 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 5], inVals[elemIdx + 4]);
|
||||
Value data3 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 7], inVals[elemIdx + 6]);
|
||||
|
||||
rewriter.create<triton::nvgpu::StoreMatrixOp>(
|
||||
loc, bitcast(addr, ptrI8SharedTy),
|
||||
ValueRange{data0, data1, data2, data3});
|
||||
}
|
||||
}
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
auto bar = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
|
||||
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32_ty, rewriter.getI32IntegerAttr(128));
|
||||
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
|
||||
kNumThreads);
|
||||
} else {
|
||||
barrier();
|
||||
}
|
||||
} else {
|
||||
auto dstStrides =
|
||||
getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false);
|
||||
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices,
|
||||
dst, smemBase, elemTy, loc, rewriter);
|
||||
}
|
||||
auto smemObj =
|
||||
SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
|
||||
SharedMemoryObject(smemBase, dstShapePerCTA, outOrd, loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
@@ -688,19 +932,16 @@ private:
|
||||
auto loc = op.getLoc();
|
||||
Value src = op.getSrc();
|
||||
Value dst = op.getResult();
|
||||
bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor());
|
||||
|
||||
auto smemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
|
||||
Value res;
|
||||
|
||||
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2
|
||||
|
||||
res = SharedToDotOperandMMAv2::convertLayout(
|
||||
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
|
||||
smemObj, getTypeConverter(), tid_val());
|
||||
|
||||
} else if (!isOuter && mmaLayout.isVolta() &&
|
||||
supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
|
||||
smemObj, getTypeConverter(), getThreadId(rewriter, loc));
|
||||
} else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1
|
||||
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
|
||||
auto srcSharedLayout = src.getType()
|
||||
.cast<RankedTensorType>()
|
||||
@@ -722,10 +963,11 @@ private:
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}; // namespace triton::gpu::ConvertLayoutOp
|
||||
}; // namespace triton::gpu::ConvertLayoutOp>
|
||||
|
||||
void populateConvertLayoutOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
|
||||
@@ -10,6 +10,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
|
||||
void populateConvertLayoutOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
#include "../Utility.h"
|
||||
|
||||
using ValueTable = std::map<std::pair<int, int>, Value>;
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
@@ -14,31 +16,32 @@ using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
SmallVector<Value>
|
||||
getThreadIds(Value threadId, ArrayRef<unsigned int> shapePerCTA,
|
||||
getThreadIds(Value threadId, ArrayRef<unsigned int> shapePerCTATile,
|
||||
ArrayRef<unsigned int> sizePerThread, ArrayRef<unsigned int> order,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
int dim = order.size();
|
||||
SmallVector<Value> threadIds(dim);
|
||||
for (unsigned k = 0; k < dim - 1; k++) {
|
||||
Value dimK = i32_val(shapePerCTA[order[k]] / sizePerThread[order[k]]);
|
||||
Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]);
|
||||
Value rem = urem(threadId, dimK);
|
||||
threadId = udiv(threadId, dimK);
|
||||
threadIds[order[k]] = rem;
|
||||
}
|
||||
Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
|
||||
Value dimK = i32_val(shapePerCTATile[order[dim - 1]]);
|
||||
threadIds[order[dim - 1]] = urem(threadId, dimK);
|
||||
return threadIds;
|
||||
}
|
||||
|
||||
int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) {
|
||||
// Get shapePerCTATile for M or N axis.
|
||||
int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) {
|
||||
auto order = layout.getOrder();
|
||||
auto shapePerCTA = getShapePerCTA(layout);
|
||||
auto shapePerCTATile = getShapePerCTATile(layout);
|
||||
|
||||
int mShapePerCTA =
|
||||
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int nShapePerCTA =
|
||||
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
return isM ? mShapePerCTA : nShapePerCTA;
|
||||
int mShapePerCTATile =
|
||||
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
int nShapePerCTATile =
|
||||
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
return isM ? mShapePerCTATile : nShapePerCTATile;
|
||||
}
|
||||
|
||||
// Get sizePerThread for M or N axis.
|
||||
@@ -91,7 +94,7 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto aTensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto aShape = aTensorTy.getShape();
|
||||
auto aShapePerCTA = getShapePerCTA(aTensorTy);
|
||||
|
||||
auto aOrder = aLayout.getOrder();
|
||||
auto order = dLayout.getOrder();
|
||||
@@ -104,10 +107,10 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
|
||||
Value strideA0 = isARow ? strideAK : strideAM;
|
||||
Value strideA1 = isARow ? strideAM : strideAK;
|
||||
int aNumPtr = 8;
|
||||
int K = aShape[1];
|
||||
int M = aShape[0];
|
||||
int K = aShapePerCTA[1];
|
||||
int M = aShapePerCTA[0];
|
||||
|
||||
auto shapePerCTA = getShapePerCTA(dLayout);
|
||||
auto shapePerCTATile = getShapePerCTATile(dLayout);
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
|
||||
Value _0 = i32_val(0);
|
||||
@@ -115,8 +118,8 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
|
||||
Value mContig = i32_val(sizePerThread[order[1]]);
|
||||
|
||||
// threadId in blocked layout
|
||||
auto threadIds =
|
||||
getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc);
|
||||
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
|
||||
rewriter, loc);
|
||||
Value threadIdM = threadIds[0];
|
||||
|
||||
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
|
||||
@@ -134,11 +137,11 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
|
||||
|
||||
SmallVector<Value> vas;
|
||||
|
||||
int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/);
|
||||
int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/);
|
||||
int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/);
|
||||
|
||||
for (unsigned k = 0; k < K; ++k)
|
||||
for (unsigned m = 0; m < M; m += mShapePerCTA)
|
||||
for (unsigned m = 0; m < M; m += mShapePerCTATile)
|
||||
for (unsigned mm = 0; mm < mSizePerThread; ++mm) {
|
||||
Value offset =
|
||||
add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK));
|
||||
@@ -155,7 +158,7 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto bShape = bTensorTy.getShape();
|
||||
auto bShapePerCTA = getShapePerCTA(bTensorTy);
|
||||
|
||||
auto bOrder = bLayout.getOrder();
|
||||
auto order = dLayout.getOrder();
|
||||
@@ -168,10 +171,10 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
|
||||
Value strideB0 = isBRow ? strideBN : strideBK;
|
||||
Value strideB1 = isBRow ? strideBK : strideBN;
|
||||
int bNumPtr = 8;
|
||||
int K = bShape[0];
|
||||
int N = bShape[1];
|
||||
int K = bShapePerCTA[0];
|
||||
int N = bShapePerCTA[1];
|
||||
|
||||
auto shapePerCTA = getShapePerCTA(dLayout);
|
||||
auto shapePerCTATile = getShapePerCTATile(dLayout);
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
|
||||
Value _0 = i32_val(0);
|
||||
@@ -179,8 +182,8 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
|
||||
Value nContig = i32_val(sizePerThread[order[0]]);
|
||||
|
||||
// threadId in blocked layout
|
||||
auto threadIds =
|
||||
getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc);
|
||||
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
|
||||
rewriter, loc);
|
||||
Value threadIdN = threadIds[1];
|
||||
|
||||
Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
|
||||
@@ -198,11 +201,11 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
|
||||
|
||||
SmallVector<Value> vbs;
|
||||
|
||||
int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/);
|
||||
int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/);
|
||||
int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/);
|
||||
|
||||
for (unsigned k = 0; k < K; ++k)
|
||||
for (unsigned n = 0; n < N; n += nShapePerCTA)
|
||||
for (unsigned n = 0; n < N; n += nShapePerCTATile)
|
||||
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
||||
Value offset =
|
||||
add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK));
|
||||
|
||||
@@ -504,7 +504,7 @@ std::function<void(int, int)> getLoadMatrixFn(
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::Float8E4M3B11FNUZType>()) {
|
||||
bool noTrans = (isA ^ order[0] == 0);
|
||||
bool noTrans = (isA ^ (order[0] == 0));
|
||||
assert(noTrans && "float8e4b15 must have row-col layout");
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
|
||||
LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
@@ -19,6 +21,16 @@ LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Value thread);
|
||||
|
||||
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
|
||||
triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value thread);
|
||||
|
||||
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
@@ -26,14 +38,15 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
// D = A * B + C
|
||||
Value A = op.getA();
|
||||
Value D = op.getResult();
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShape = A.getType().cast<RankedTensorType>().getShape();
|
||||
auto AShapePerCTA = getShapePerCTA(A.getType());
|
||||
size_t reduceAxis = 1;
|
||||
unsigned K = AShape[reduceAxis];
|
||||
unsigned K = AShapePerCTA[reduceAxis];
|
||||
bool isOuter = K == 1;
|
||||
|
||||
MmaEncodingAttr mmaLayout = D.getType()
|
||||
@@ -45,6 +58,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
return convertMMA884(op, adaptor, getTypeConverter(), rewriter);
|
||||
if (mmaLayout.isAmpere())
|
||||
return convertMMA16816(op, adaptor, getTypeConverter(), rewriter);
|
||||
if (mmaLayout.isHopper())
|
||||
return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
|
||||
getThreadId(rewriter, loc));
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported MMA kind found when converting DotOp to LLVM.");
|
||||
@@ -61,9 +77,68 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct DotAsyncOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotAsyncOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::DotAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::DotAsyncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
// D = A * B + C
|
||||
Value A = op.getA();
|
||||
Value D = op.getResult();
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShapePerCTA = getShapePerCTA(A.getType());
|
||||
size_t reduceAxis = 1;
|
||||
unsigned K = AShapePerCTA[reduceAxis];
|
||||
bool isOuter = K == 1;
|
||||
|
||||
MmaEncodingAttr mmaLayout = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MmaEncodingAttr>();
|
||||
if (!isOuter && mmaLayout &&
|
||||
supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) {
|
||||
if (mmaLayout.isHopper()) {
|
||||
return convertAsyncWGMMA(op, adaptor, getTypeConverter(), rewriter,
|
||||
getThreadId(rewriter, loc));
|
||||
}
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported MMA kind found when converting DotAsyncOp to LLVM.");
|
||||
}
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported DotAsyncOp found when converting TritonGPU to LLVM.");
|
||||
}
|
||||
};
|
||||
|
||||
struct DotWaitOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::DotWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto pendings = op.getPendings();
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitOp>(op.getLoc(), pendings);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<DotOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<DotAsyncOpConversion>(typeConverter, allocation, benefit);
|
||||
patterns.add<DotWaitOpConversion>(typeConverter, allocation, benefit);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit);
|
||||
|
||||
|
||||
@@ -5,19 +5,20 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
|
||||
using ValueTableFMA = std::map<std::pair<int, int>, Value>;
|
||||
|
||||
static ValueTableFMA getValueTableFromStructFMA(
|
||||
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
|
||||
Value val, int K, int n0, int shapePerCTATile, int sizePerThread,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Type type) {
|
||||
ValueTableFMA res;
|
||||
auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type);
|
||||
int index = 0;
|
||||
for (unsigned k = 0; k < K; ++k) {
|
||||
for (unsigned m = 0; m < n0; m += shapePerCTA)
|
||||
for (unsigned m = 0; m < n0; m += shapePerCTATile)
|
||||
for (unsigned mm = 0; mm < sizePerThread; ++mm) {
|
||||
res[{m + mm, k}] = elems[index++];
|
||||
}
|
||||
@@ -40,8 +41,8 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = D.getType().cast<RankedTensorType>();
|
||||
|
||||
auto aShape = aTensorTy.getShape();
|
||||
auto bShape = bTensorTy.getShape();
|
||||
auto aShapePerCTA = getShapePerCTA(aTensorTy);
|
||||
auto bShapePerCTA = getShapePerCTA(bTensorTy);
|
||||
|
||||
BlockedEncodingAttr dLayout =
|
||||
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
@@ -53,41 +54,42 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
Value llB = adaptor.getB();
|
||||
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
auto shapePerCTA = getShapePerCTA(dLayout);
|
||||
auto shapePerCTATile = getShapePerCTATile(dLayout);
|
||||
|
||||
int K = aShape[1];
|
||||
int M = aShape[0];
|
||||
int N = bShape[1];
|
||||
int K = aShapePerCTA[1];
|
||||
int M = aShapePerCTA[0];
|
||||
int N = bShapePerCTA[1];
|
||||
|
||||
int mShapePerCTA =
|
||||
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int mShapePerCTATile =
|
||||
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
int mSizePerThread =
|
||||
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
int nShapePerCTA =
|
||||
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int nShapePerCTATile =
|
||||
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
|
||||
int nSizePerThread =
|
||||
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
|
||||
auto has =
|
||||
getValueTableFromStructFMA(llA, K, M, mShapePerCTA, mSizePerThread,
|
||||
getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread,
|
||||
rewriter, loc, typeConverter, aTensorTy);
|
||||
auto hbs =
|
||||
getValueTableFromStructFMA(llB, K, N, nShapePerCTA, nSizePerThread,
|
||||
getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread,
|
||||
rewriter, loc, typeConverter, bTensorTy);
|
||||
|
||||
SmallVector<Value> ret = cc;
|
||||
bool isCRow = order[0] == 1;
|
||||
|
||||
for (unsigned k = 0; k < K; k++) {
|
||||
for (unsigned m = 0; m < M; m += mShapePerCTA)
|
||||
for (unsigned n = 0; n < N; n += nShapePerCTA)
|
||||
for (unsigned m = 0; m < M; m += mShapePerCTATile)
|
||||
for (unsigned n = 0; n < N; n += nShapePerCTATile)
|
||||
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
|
||||
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
||||
int mIdx = m / mShapePerCTA * mSizePerThread + mm;
|
||||
int nIdx = n / nShapePerCTA * nSizePerThread + nn;
|
||||
int mIdx = m / mShapePerCTATile * mSizePerThread + mm;
|
||||
int nIdx = n / nShapePerCTATile * nSizePerThread + nn;
|
||||
|
||||
int z = isCRow ? mIdx * N / nShapePerCTA * mSizePerThread + nIdx
|
||||
: nIdx * M / mShapePerCTA * nSizePerThread + mIdx;
|
||||
int z = isCRow
|
||||
? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
|
||||
: nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
|
||||
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
||||
hbs[{n + nn, k}], ret[z]);
|
||||
}
|
||||
|
||||
@@ -170,16 +170,17 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
|
||||
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
|
||||
aTensorTy.getShape().end());
|
||||
auto dShape = dTensorTy.getShape();
|
||||
auto aShapePerCTA = triton::gpu::getShapePerCTA(aTensorTy);
|
||||
auto bShapePerCTA = triton::gpu::getShapePerCTA(bTensorTy);
|
||||
auto dShapePerCTA = triton::gpu::getShapePerCTA(dTensorTy);
|
||||
|
||||
int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
|
||||
auto repA =
|
||||
aTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
|
||||
aTensorTy.getShape(), bitwidth);
|
||||
aShapePerCTA, bitwidth);
|
||||
auto repB =
|
||||
bTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
|
||||
bTensorTy.getShape(), bitwidth);
|
||||
bShapePerCTA, bitwidth);
|
||||
|
||||
assert(repA[1] == repB[0]);
|
||||
int repM = repA[0], repN = repB[1], repK = repA[1];
|
||||
|
||||
391
lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp
Normal file
391
lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp
Normal file
@@ -0,0 +1,391 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getShapePerCTATile;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
triton::nvgpu::WGMMAEltType getMmaRetType(Value d) {
|
||||
auto dTy = d.getType().cast<RankedTensorType>().getElementType();
|
||||
if (dTy.isF32()) {
|
||||
return triton::nvgpu::WGMMAEltType::f32;
|
||||
} else if (dTy.isF16()) {
|
||||
return triton::nvgpu::WGMMAEltType::f16;
|
||||
} else if (dTy.isInteger(32)) {
|
||||
return triton::nvgpu::WGMMAEltType::s32;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported mma result type found");
|
||||
}
|
||||
}
|
||||
|
||||
triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
|
||||
auto aTy = a.getType().cast<RankedTensorType>().getElementType();
|
||||
if (aTy.isF16()) {
|
||||
return triton::nvgpu::WGMMAEltType::f16;
|
||||
} else if (aTy.isBF16()) {
|
||||
return triton::nvgpu::WGMMAEltType::bf16;
|
||||
} else if (aTy.isF32() && allowTF32) {
|
||||
return triton::nvgpu::WGMMAEltType::tf32;
|
||||
} else if (aTy.isInteger(8)) {
|
||||
return triton::nvgpu::WGMMAEltType::s8;
|
||||
} else if (aTy.isFloat8E5M2()) {
|
||||
return triton::nvgpu::WGMMAEltType::e5m2;
|
||||
} else if (aTy.isFloat8E4M3FN()) {
|
||||
return triton::nvgpu::WGMMAEltType::e4m3;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported mma operand type found");
|
||||
}
|
||||
}
|
||||
|
||||
mlir::triton::nvgpu::WGMMADescMode
|
||||
getModeFromLayout(const SharedEncodingAttr &layout, uint32_t widthInByte) {
|
||||
int perPhase = layout.getPerPhase();
|
||||
int maxPhase = layout.getMaxPhase();
|
||||
uint32_t swizzlingByteWidth = 0;
|
||||
|
||||
mlir::triton::nvgpu::WGMMADescMode mode;
|
||||
if (perPhase == 4 && maxPhase == 2) {
|
||||
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle32;
|
||||
swizzlingByteWidth = 32;
|
||||
} else if (perPhase == 2 && maxPhase == 4) {
|
||||
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle64;
|
||||
swizzlingByteWidth = 64;
|
||||
} else if (perPhase == 1 && maxPhase == 8) {
|
||||
mode = mlir::triton::nvgpu::WGMMADescMode::swizzle128;
|
||||
swizzlingByteWidth = 128;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported shared layout.");
|
||||
}
|
||||
|
||||
// TODO[biaow]: remove it once we support swizzling size larger than matrix
|
||||
// width, which requires padding the matrix width to the swizzling size when
|
||||
// allocating shared memory.
|
||||
assert(swizzlingByteWidth <= widthInByte &&
|
||||
"swizzling size larger than matrix width is not supported.");
|
||||
return mode;
|
||||
}
|
||||
|
||||
class DotOpMmaV3SmemLoader {
|
||||
public:
|
||||
DotOpMmaV3SmemLoader(Value tensor, const SharedMemoryObject &smemObj,
|
||||
SmallVector<int64_t> shape, Value warpId,
|
||||
unsigned int dimWpt, bool trans,
|
||||
SmallVector<unsigned int> instrShape,
|
||||
ConversionPatternRewriter &rewriter, Location loc)
|
||||
: base(smemObj.base), shape(shape), warpId(warpId), dimWpt(dimWpt),
|
||||
trans(trans), instrShape(instrShape), rewriter(rewriter), loc(loc) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
ord = sharedLayout.getOrder();
|
||||
const int perPhase = sharedLayout.getPerPhase();
|
||||
const int maxPhase = sharedLayout.getMaxPhase();
|
||||
elemBytes = tensorTy.getElementTypeBitWidth() / 8;
|
||||
elemsPerSwizzlingRow = 128 / perPhase / elemBytes;
|
||||
elemsPerSwizzlingRowVal = i32_val(elemsPerSwizzlingRow);
|
||||
|
||||
uint32_t widthInByte = shape[ord[0]] * elemBytes;
|
||||
mode = getModeFromLayout(sharedLayout, widthInByte);
|
||||
|
||||
baseDesc = rewriter.create<triton::nvgpu::WGMMADescCreateOp>(
|
||||
loc, i64_ty, base, i32_val(shape[ord[1]]), mode);
|
||||
}
|
||||
|
||||
Value smemLoad(int a, int b) {
|
||||
Value k = i32_val(b * instrShape[1]);
|
||||
Value m = add(i32_val(a * dimWpt * instrShape[0]),
|
||||
mul(warpId, i32_val(instrShape[0])));
|
||||
if (trans) {
|
||||
std::swap(k, m);
|
||||
}
|
||||
Value leading_offset = mul(udiv(k, elemsPerSwizzlingRowVal),
|
||||
i32_val(shape[ord[1]] * elemsPerSwizzlingRow));
|
||||
Value stride_offset = mul(m, elemsPerSwizzlingRowVal);
|
||||
Value offset = add(add(leading_offset, stride_offset),
|
||||
urem(k, elemsPerSwizzlingRowVal));
|
||||
Value off1 = mul(i32_val(elemBytes), offset);
|
||||
Value off_ = zext(i64_ty, udiv(off1, i32_val(16)));
|
||||
|
||||
return add(baseDesc, off_);
|
||||
}
|
||||
|
||||
private:
|
||||
Value base;
|
||||
SmallVector<int64_t> shape;
|
||||
Value warpId;
|
||||
int dimWpt;
|
||||
bool trans;
|
||||
Value elemsPerSwizzlingRowVal;
|
||||
mlir::triton::nvgpu::WGMMADescMode mode;
|
||||
SmallVector<unsigned int> instrShape;
|
||||
ArrayRef<unsigned> ord;
|
||||
ConversionPatternRewriter &rewriter;
|
||||
Location loc;
|
||||
int elemsPerSwizzlingRow;
|
||||
int elemBytes;
|
||||
Value baseDesc;
|
||||
};
|
||||
|
||||
DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
const MmaEncodingAttr &mmaEncoding, Value tensor,
|
||||
const SharedMemoryObject &smemObj, Value thread) {
|
||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto aSharedLayout = aTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(aSharedLayout && "only support load dot operand from shared.");
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto wpt = mmaEncoding.getWarpsPerCTA();
|
||||
auto aOrd = aSharedLayout.getOrder();
|
||||
bool transA = aOrd[0] == 0;
|
||||
auto shapePerCTA = getShapePerCTA(aTensorTy);
|
||||
|
||||
int numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0] * wpt[0]);
|
||||
int numRepK = ceil<unsigned>(shapePerCTA[1], instrShape[2]);
|
||||
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value warpM = urem(warp, i32_val(wpt[0]));
|
||||
Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
|
||||
|
||||
return {tensor,
|
||||
smemObj,
|
||||
shapePerCTA,
|
||||
warpId,
|
||||
wpt[0],
|
||||
transA,
|
||||
{instrShape[0], instrShape[2]},
|
||||
rewriter,
|
||||
loc};
|
||||
}
|
||||
|
||||
DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
MmaEncodingAttr &mmaEncoding, Value tensor,
|
||||
const SharedMemoryObject &smemObj, Value thread) {
|
||||
auto bTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
assert(bSharedLayout && "only support load B from shared.");
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto wpt = mmaEncoding.getWarpsPerCTA();
|
||||
auto bOrd = bSharedLayout.getOrder();
|
||||
bool transB = bOrd[0] == 1;
|
||||
auto shapePerCTA = triton::gpu::getShapePerCTA(bTensorTy);
|
||||
|
||||
int numRepK = ceil<unsigned>(shapePerCTA[0], instrShape[2]);
|
||||
int numRepN = ceil<unsigned>(shapePerCTA[1], instrShape[1] * wpt[1]);
|
||||
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value warpMN = udiv(warp, i32_val(wpt[0]));
|
||||
Value warpN = urem(warpMN, i32_val(wpt[1]));
|
||||
Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
|
||||
|
||||
return {tensor,
|
||||
smemObj,
|
||||
shapePerCTA,
|
||||
warpId,
|
||||
wpt[1],
|
||||
transB,
|
||||
{instrShape[1], instrShape[2]},
|
||||
rewriter,
|
||||
loc};
|
||||
}
|
||||
|
||||
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *op, Value a, Value b, Value c, Value d,
|
||||
Value loadedA, Value loadedB, Value loadedC,
|
||||
bool allowTF32, const SharedMemoryObject &smemObjA,
|
||||
const SharedMemoryObject &smemObjB, bool sync,
|
||||
Value thread) {
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
auto aSharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto mmaEncoding = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
|
||||
auto aOrd = aSharedLayout.getOrder();
|
||||
auto bOrd = bSharedLayout.getOrder();
|
||||
bool transA = aOrd[0] == 0;
|
||||
bool transB = bOrd[0] == 1;
|
||||
auto dShapePerCTA = getShapePerCTA(dTensorTy);
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto accSize = 2 * (instrShape[1] / 4);
|
||||
Type resElemTy = dTensorTy.getElementType();
|
||||
llvm::SmallVector<Type> elemTypes(accSize, resElemTy);
|
||||
auto accTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
|
||||
int M = 4 * instrShape[0];
|
||||
int N = instrShape[1];
|
||||
int K = instrShape[2];
|
||||
|
||||
auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
|
||||
int numRepM = ceil<unsigned>(dShapePerCTA[0], shapePerCTATile[0]);
|
||||
int numRepN = ceil<unsigned>(dShapePerCTA[1], shapePerCTATile[1]);
|
||||
int numRepK = ceil<unsigned>(aTensorTy.getShape()[1], instrShape[2]);
|
||||
|
||||
DotOpMmaV3SmemLoader aLoader =
|
||||
loadA(typeConverter, rewriter, loc, mmaEncoding, a, smemObjA, thread);
|
||||
DotOpMmaV3SmemLoader bLoader =
|
||||
loadB(typeConverter, rewriter, loc, mmaEncoding, b, smemObjB, thread);
|
||||
|
||||
auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy);
|
||||
|
||||
triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(d);
|
||||
assert(eltTypeC != triton::nvgpu::WGMMAEltType::f16 &&
|
||||
"TODO support f16 return type. This requires packing C into "
|
||||
"vector<2xf16> type.");
|
||||
triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(a, allowTF32);
|
||||
triton::nvgpu::WGMMAEltType eltTypeB = eltTypeA;
|
||||
|
||||
triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col
|
||||
: triton::nvgpu::WGMMALayout::row;
|
||||
triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row
|
||||
: triton::nvgpu::WGMMALayout::col;
|
||||
|
||||
auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
|
||||
int numTMADescs =
|
||||
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
|
||||
if (numTMADescs == 0)
|
||||
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
|
||||
rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
|
||||
|
||||
llvm::SmallVector<Value> mmaOut(accSize);
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
// reuse the same mmaOut
|
||||
for (int i = 0; i < accSize; ++i) {
|
||||
mmaOut[i] = fc[(m * numRepN + n) * accSize + i];
|
||||
}
|
||||
Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy);
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto a = aLoader.smemLoad(m, k);
|
||||
auto b = bLoader.smemLoad(n, k);
|
||||
ValueRange operands{a, b, d};
|
||||
d = rewriter.create<triton::nvgpu::WGMMAOp>(loc, accTy, a, b, d, M, N,
|
||||
K, eltTypeC, eltTypeA,
|
||||
eltTypeB, layoutA, layoutB);
|
||||
}
|
||||
auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy);
|
||||
for (int i = 0; i < acc.size(); ++i) {
|
||||
fc[(m * numRepN + n) * accSize + i] = acc[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
|
||||
|
||||
if (sync)
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitOp>(loc, 0);
|
||||
|
||||
for (auto &elem : fc) {
|
||||
elem = bitcast(elem, resElemTy);
|
||||
}
|
||||
// replace with new packed result
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
mmaEncoding.getContext(), SmallVector<Type>(fc.size(), resElemTy));
|
||||
auto res = typeConverter->packLLElements(loc, fc, rewriter, structTy);
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Loading $c to registers, returns a Value.
|
||||
Value loadC(Value tensor, Value llTensor) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto mmaEncoding = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
|
||||
assert(mmaEncoding && "Currently, we only support $c with a mma layout.");
|
||||
auto instrShape = mmaEncoding.getInstrShape();
|
||||
auto wpt = mmaEncoding.getWarpsPerCTA();
|
||||
auto shapePerCTA = getShapePerCTA(tensorTy);
|
||||
auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
|
||||
|
||||
int numRepM = ceil<unsigned>(shapePerCTA[0], shapePerCTATile[0]);
|
||||
int numRepN = ceil<unsigned>(shapePerCTA[1], shapePerCTATile[1]);
|
||||
|
||||
size_t fcSize = 2 * (instrShape[1] / 4) * numRepM * numRepN;
|
||||
|
||||
auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
|
||||
assert(structTy.getBody().size() == fcSize &&
|
||||
"DotOp's $c operand should pass the same number of values as $d in "
|
||||
"mma layout.");
|
||||
return llTensor;
|
||||
}
|
||||
|
||||
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Value thread) {
|
||||
auto loc = op.getLoc();
|
||||
Value A = op.getA();
|
||||
Value B = op.getB();
|
||||
Value C = op.getC();
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
|
||||
assert(ATensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
BTensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
"Both $a and %b should be Shared layout.");
|
||||
|
||||
Value llA, llB, llC;
|
||||
llA = adaptor.getA();
|
||||
llB = adaptor.getB();
|
||||
llC = loadC(C, adaptor.getC());
|
||||
|
||||
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
|
||||
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
|
||||
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
|
||||
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
|
||||
smemObjB, true, thread);
|
||||
}
|
||||
|
||||
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
|
||||
triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value thread) {
|
||||
auto loc = op.getLoc();
|
||||
Value A = op.getA();
|
||||
Value B = op.getB();
|
||||
Value C = op.getC();
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
|
||||
assert(ATensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
BTensorTy.getEncoding().isa<SharedEncodingAttr>() &&
|
||||
"Both $a and %b should be Shared layout.");
|
||||
|
||||
Value llA, llB, llC;
|
||||
llA = adaptor.getA();
|
||||
llB = adaptor.getB();
|
||||
llC = loadC(C, adaptor.getC());
|
||||
|
||||
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
|
||||
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
|
||||
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
|
||||
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
|
||||
smemObjB, false, thread);
|
||||
}
|
||||
@@ -353,6 +353,13 @@ static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
|
||||
llvm_unreachable("unimplemented code path");
|
||||
}
|
||||
|
||||
inline Type getElementType(Value value) {
|
||||
auto type = value.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return tensorType.getElementType();
|
||||
return type;
|
||||
}
|
||||
|
||||
inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
|
||||
Type srcTy,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
@@ -1149,7 +1156,8 @@ struct IndexCastOpLowering
|
||||
|
||||
void populateElementwiseOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
PatternBenefit benefit) {
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit) {
|
||||
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
|
||||
@@ -1222,3 +1230,161 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
// __nv_expf for higher-precision calculation
|
||||
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
struct FPExtOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::FPExtOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isF32() && srcTy.isF16()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
SmallVector<Value> createDestOps(LLVM::FPExtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
return {
|
||||
FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0][0])};
|
||||
}
|
||||
};
|
||||
|
||||
struct FPTruncOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::FPTruncOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isF16() && srcTy.isF32()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
SmallVector<Value> createDestOps(LLVM::FPTruncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
return {
|
||||
FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0][0])};
|
||||
}
|
||||
};
|
||||
|
||||
struct TruncOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::TruncOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isInteger(16) && srcTy.isInteger(32)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
SmallVector<Value> createDestOps(LLVM::TruncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.u16.u32");
|
||||
auto res = builder.newOperand("=h");
|
||||
auto operand = builder.newOperand(operands[0][0], "r");
|
||||
cvt(res, operand);
|
||||
return {builder.launch(rewriter, loc, i16_ty, false)};
|
||||
}
|
||||
};
|
||||
|
||||
struct SExtOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::SExtOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
SmallVector<Value> createDestOps(LLVM::SExtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.s32.s16");
|
||||
auto res = builder.newOperand("=r");
|
||||
auto operand = builder.newOperand(operands[0][0], "h");
|
||||
cvt(res, operand);
|
||||
return {builder.launch(rewriter, loc, i32_ty, false)};
|
||||
}
|
||||
};
|
||||
|
||||
struct ZExtOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::ZExtOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
SmallVector<Value> createDestOps(LLVM::ZExtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.u32.u16");
|
||||
auto res = builder.newOperand("=r");
|
||||
auto operand = builder.newOperand(operands[0][0], "h");
|
||||
cvt(res, operand);
|
||||
return {builder.launch(rewriter, loc, i32_ty, false)};
|
||||
}
|
||||
};
|
||||
|
||||
bool isLegalElementwiseOp(Operation *op) {
|
||||
if (isa<LLVM::FPExtOp>(op)) {
|
||||
return FPExtOpConversion::isLegalOp(cast<LLVM::FPExtOp>(op));
|
||||
} else if (isa<LLVM::FPTruncOp>(op)) {
|
||||
return FPTruncOpConversion::isLegalOp(cast<LLVM::FPTruncOp>(op));
|
||||
} else if (isa<LLVM::TruncOp>(op)) {
|
||||
return TruncOpConversion::isLegalOp(cast<LLVM::TruncOp>(op));
|
||||
} else if (isa<LLVM::SExtOp>(op)) {
|
||||
return SExtOpConversion::isLegalOp(cast<LLVM::SExtOp>(op));
|
||||
} else if (isa<LLVM::ZExtOp>(op)) {
|
||||
return ZExtOpConversion::isLegalOp(cast<LLVM::ZExtOp>(op));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void populateElementwiseOpToPTXPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<FPExtOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FPTruncOpConversion>(typeConverter, benefit);
|
||||
patterns.add<TruncOpConversion>(typeConverter, benefit);
|
||||
patterns.add<SExtOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ZExtOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
@@ -8,8 +8,13 @@ using namespace mlir::triton;
|
||||
|
||||
void populateElementwiseOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
PatternBenefit benefit);
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
bool isLegalElementwiseOp(Operation *op);
|
||||
|
||||
void populateElementwiseOpToPTXPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -4,10 +4,15 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "LoadStoreOpToLLVM.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
@@ -64,6 +69,9 @@ struct LoadOpConversion
|
||||
Value other = op.getOther();
|
||||
|
||||
// adaptor values
|
||||
assert(!isTensorPointerType(ptr.getType()) &&
|
||||
"Cannot convert load with a tensor pointer into LLVM; "
|
||||
"this case should be transformed to normal load before lowering");
|
||||
Value llPtr = adaptor.getPtr();
|
||||
Value llMask = adaptor.getMask();
|
||||
Value llOther = adaptor.getOther();
|
||||
@@ -378,6 +386,251 @@ struct StoreOpConversion
|
||||
return success();
|
||||
}
|
||||
};
|
||||
// TODO: refactor to save common logic with insertsliceasyncv2
|
||||
struct StoreAsyncOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::StoreAsyncOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::StoreAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
StoreAsyncOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
ModuleAllocation &allocation,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
const TensorPtrMapT *tensorPtrMap,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::StoreAsyncOp>(
|
||||
converter, allocation, tmaMetadata, benefit),
|
||||
tensorPtrMap(tensorPtrMap) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
auto dst = op.getDst();
|
||||
auto src = op.getSrc();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto elemTy = srcTy.getElementType();
|
||||
|
||||
auto rank = srcTy.getRank();
|
||||
assert(rank > 0 && rank <= 5);
|
||||
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp");
|
||||
|
||||
auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
|
||||
assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp");
|
||||
|
||||
int numTMADescs = getNumTMADescs(llFuncOp);
|
||||
assert(numTMADescs > 0);
|
||||
|
||||
auto sharedLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(sharedLayout && "expected shared encoding");
|
||||
|
||||
mlir::triton::gpu::TMAInfo tmaInfo;
|
||||
|
||||
tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy);
|
||||
tmaInfo.tensorRank = rank;
|
||||
assert(tmaMetadata);
|
||||
|
||||
auto inOrder = sharedLayout.getOrder();
|
||||
unsigned TMADescIdx = tmaMetadata->size();
|
||||
unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
|
||||
auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
|
||||
auto dstOrder = makeTensorPtr.getOrder();
|
||||
|
||||
unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
|
||||
tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
|
||||
tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;
|
||||
|
||||
auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
|
||||
auto it = std::find(order.begin(), order.end(), i);
|
||||
assert(it != order.end());
|
||||
return std::distance(order.begin(), it);
|
||||
};
|
||||
|
||||
std::vector<int32_t> globalDimsArgIdx;
|
||||
std::vector<int32_t> globalStridesArgIdx;
|
||||
// constant values are mapped to (-1 - value)
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
int32_t argIdx = -1;
|
||||
auto dim = getDimOfOrder(dstOrder, i);
|
||||
argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
|
||||
globalDimsArgIdx.emplace_back(argIdx);
|
||||
// handle constant stride
|
||||
argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
|
||||
globalStridesArgIdx.emplace_back(argIdx);
|
||||
}
|
||||
|
||||
tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
|
||||
tmaInfo.globalStridesArgIdx = globalStridesArgIdx;
|
||||
std::vector<uint32_t> boxDims;
|
||||
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
|
||||
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
|
||||
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
|
||||
auto tensorShape = makeTensorPtr.getResult()
|
||||
.getType()
|
||||
.cast<triton::PointerType>()
|
||||
.getPointeeType()
|
||||
.cast<RankedTensorType>()
|
||||
.getShape();
|
||||
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
|
||||
// magic 128 bytes
|
||||
uint32_t bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8;
|
||||
uint32_t numBox{1};
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
auto dim = getDimOfOrder(dstOrder, i);
|
||||
auto tNumElems = shapePerCTA[dim];
|
||||
if (i == 0 && tNumElems * bytesPerElem > 128) {
|
||||
tNumElems = 128 / bytesPerElem;
|
||||
numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
|
||||
}
|
||||
boxDims.emplace_back(tNumElems);
|
||||
}
|
||||
std::vector<uint32_t> elementStrides(rank, 1);
|
||||
tmaInfo.boxDims = boxDims;
|
||||
tmaInfo.elementStrides = elementStrides;
|
||||
|
||||
CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
assert(
|
||||
((elemTy.getIntOrFloatBitWidth() == 16 && sharedLayout.getVec() == 8) or
|
||||
(elemTy.getIntOrFloatBitWidth() == 32 &&
|
||||
sharedLayout.getVec() == 4)) &&
|
||||
"Unexpected shared layout for StoreAsyncOp");
|
||||
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
|
||||
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
|
||||
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
|
||||
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
else
|
||||
llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp");
|
||||
tmaInfo.swizzle = swizzle;
|
||||
tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
|
||||
tmaInfo.l2Promotion =
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
|
||||
tmaInfo.oobFill =
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
|
||||
|
||||
tmaMetadata->emplace_back(tmaInfo);
|
||||
|
||||
Value llDst = adaptor.getDst();
|
||||
Value llSrc = adaptor.getSrc();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto smemObj = getSharedMemoryObjectFromStruct(loc, llSrc, rewriter);
|
||||
|
||||
SmallVector<Value> offsetVals;
|
||||
for (auto i = 0; i < srcShape.size(); ++i) {
|
||||
offsetVals.emplace_back(i32_val(0));
|
||||
}
|
||||
|
||||
Value tmaDesc =
|
||||
llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
|
||||
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
|
||||
typeConverter->convertType(rewriter.getI8Type()), 3);
|
||||
|
||||
auto threadId = getThreadId(rewriter, loc);
|
||||
Value pred = icmp_eq(threadId, i32_val(0));
|
||||
|
||||
auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
|
||||
dst.getType());
|
||||
uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1,
|
||||
std::multiplies<uint32_t>());
|
||||
|
||||
Value clusterCTAId = getClusterCTAId(rewriter, loc);
|
||||
SmallVector<Value> multiDimClusterCTAId =
|
||||
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
|
||||
|
||||
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
|
||||
|
||||
for (uint32_t b = 0; b < numBox; ++b) {
|
||||
SmallVector<Value> coord;
|
||||
// raw coord
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
auto dim = getDimOfOrder(dstOrder, i);
|
||||
coord.push_back(llCoord[dim]);
|
||||
}
|
||||
// coord with box and cta offset
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
auto dim = getDimOfOrder(dstOrder, i);
|
||||
if (i == 0) {
|
||||
coord[i] = add(coord[i], i32_val(b * boxDims[i]));
|
||||
auto CTAOffset =
|
||||
mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i]));
|
||||
coord[i] = add(coord[i], CTAOffset);
|
||||
} else {
|
||||
coord[i] = add(coord[i],
|
||||
mul(multiDimClusterCTAId[dim], i32_val(boxDims[i])));
|
||||
}
|
||||
}
|
||||
Value srcOffset = i32_val(b * boxStride);
|
||||
auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
|
||||
Value srcPtrBase = gep(srcPtrTy, smemObj.base, srcOffset);
|
||||
auto addr = bitcast(srcPtrBase, ptrI8SharedTy);
|
||||
rewriter.create<triton::nvgpu::TMAStoreTiledOp>(loc, tmaDesc, addr, pred,
|
||||
coord);
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
|
||||
if (ty.isF16()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if (ty.isF32()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp");
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getArgIdx(Value v) const {
|
||||
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
} else if (v.getParentBlock()->isEntryBlock() && v.isa<BlockArgument>()) {
|
||||
// in entryblock and is BlockArgument; Because argument of func are
|
||||
// arugments of entryblock bb0 in MLIR
|
||||
return v.cast<BlockArgument>().getArgNumber();
|
||||
} else if (v.getParentBlock()->isEntryBlock() &&
|
||||
(!v.isa<BlockArgument>())) {
|
||||
// in entryblock but not BlockArgument
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
} else if (!v.getParentBlock()->isEntryBlock()) {
|
||||
// in non-entryblock
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
} else {
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
int getNumTMADescs(LLVM::LLVMFuncOp func) const {
|
||||
if (!func->hasAttr(kAttrNumTMALoadDescsName)) {
|
||||
llvm::report_fatal_error("TritonGPU module should contain a "
|
||||
"triton_gpu.num-tma-load attribute");
|
||||
return -1;
|
||||
}
|
||||
if (!func->hasAttr(kAttrNumTMAStoreDescsName)) {
|
||||
llvm::report_fatal_error("TritonGPU module should contain a "
|
||||
"triton_gpu.num-tma-store attribute");
|
||||
return -1;
|
||||
}
|
||||
return func->getAttr(kAttrNumTMAStoreDescsName)
|
||||
.cast<IntegerAttr>()
|
||||
.getInt() +
|
||||
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
const TensorPtrMapT *tensorPtrMap;
|
||||
};
|
||||
|
||||
struct AtomicCASOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
|
||||
@@ -854,11 +1107,386 @@ struct InsertSliceAsyncOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct InsertSliceAsyncV2OpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::InsertSliceAsyncV2Op> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::InsertSliceAsyncV2Op>::
|
||||
ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
InsertSliceAsyncV2OpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
|
||||
ModuleAllocation &allocation,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
const TensorPtrMapT *tensorPtrMap,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::InsertSliceAsyncV2Op>(converter, allocation,
|
||||
tmaMetadata, benefit),
|
||||
tensorPtrMap(tensorPtrMap) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::InsertSliceAsyncV2Op op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto elemTy = resultTy.getElementType();
|
||||
auto rank = resultTy.getRank() - 1;
|
||||
|
||||
// TODO: support any valid rank in (3, 4, 5)
|
||||
assert(rank > 0 && rank <= 5);
|
||||
SmallVector<unsigned> shape;
|
||||
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for InsertSliceAsyncV2Op");
|
||||
auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
|
||||
assert(llFuncOp && "LLVMFuncOp not found for InsertSliceAsyncV2Op");
|
||||
int numTMADescs = getNumTMADescs(llFuncOp);
|
||||
assert(numTMADescs > 0);
|
||||
auto sharedLayout = resultTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(sharedLayout && "unexpected layout of InsertSliceAsyncV2Op");
|
||||
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
|
||||
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
|
||||
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
|
||||
|
||||
mlir::triton::gpu::TMAInfo tmaInfo;
|
||||
|
||||
tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy);
|
||||
tmaInfo.tensorRank = rank;
|
||||
|
||||
assert(tmaMetadata);
|
||||
unsigned TMADescIdx = tmaMetadata->size();
|
||||
unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
|
||||
auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
|
||||
auto inOrder = makeTensorPtr.getOrder();
|
||||
unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
|
||||
tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
|
||||
tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;
|
||||
|
||||
auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
|
||||
auto it = std::find(order.begin(), order.end(), i);
|
||||
assert(it != order.end());
|
||||
return std::distance(order.begin(), it);
|
||||
};
|
||||
|
||||
std::vector<int32_t> globalDimsArgIdx;
|
||||
std::vector<int32_t> globalStridesArgIdx;
|
||||
// constant values are mapped to (-1 - value)
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
int32_t argIdx = -1;
|
||||
auto dim = getDimOfOrder(inOrder, i);
|
||||
argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
|
||||
globalDimsArgIdx.emplace_back(argIdx);
|
||||
// handle constant stride
|
||||
argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
|
||||
globalStridesArgIdx.emplace_back(argIdx);
|
||||
}
|
||||
|
||||
tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
|
||||
tmaInfo.globalStridesArgIdx = globalStridesArgIdx;
|
||||
|
||||
std::vector<uint32_t> boxDims;
|
||||
auto tensorShape = makeTensorPtr.getResult()
|
||||
.getType()
|
||||
.cast<triton::PointerType>()
|
||||
.getPointeeType()
|
||||
.cast<RankedTensorType>()
|
||||
.getShape();
|
||||
|
||||
SmallVector<unsigned> numMcast(rank);
|
||||
unsigned accNumMcast = 1;
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
|
||||
accNumMcast *= numMcast[i];
|
||||
}
|
||||
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
|
||||
for (size_t i = 0; i < rank; ++i) {
|
||||
auto dim = getDimOfOrder(inOrder, i);
|
||||
// in case of TMA multicast, we should always slice along higher order
|
||||
// dimensions
|
||||
if (i == rank - 1) {
|
||||
assert(shapePerCTA[dim] >= accNumMcast &&
|
||||
"cases when the size of the highest order is smaller "
|
||||
"than numMcasts is not implemented");
|
||||
boxDims.emplace_back(shapePerCTA[dim] / accNumMcast);
|
||||
} else {
|
||||
boxDims.emplace_back(shapePerCTA[dim]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint32_t> elementStrides(rank, 1);
|
||||
tmaInfo.elementStrides = elementStrides;
|
||||
|
||||
CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
|
||||
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
|
||||
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
|
||||
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
else
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported shared layout for InsertSliceAsyncV2Op");
|
||||
|
||||
tmaInfo.swizzle = swizzle;
|
||||
tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
|
||||
tmaInfo.l2Promotion =
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
|
||||
tmaInfo.oobFill =
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
|
||||
|
||||
uint32_t numBoxes = 1;
|
||||
uint32_t elemSizeOfBytes = elemTy.getIntOrFloatBitWidth() / 8;
|
||||
if (swizzle == CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
while (elemSizeOfBytes * boxDims[0] > 128) {
|
||||
boxDims[0] = boxDims[0] / 2;
|
||||
numBoxes *= 2;
|
||||
}
|
||||
}
|
||||
tmaInfo.boxDims = boxDims;
|
||||
tmaMetadata->emplace_back(tmaInfo);
|
||||
|
||||
uint32_t elemsPerBox =
|
||||
std::accumulate(boxDims.begin(), boxDims.end(), 1, std::multiplies{});
|
||||
|
||||
Value clusterCTAId = getClusterCTAId(rewriter, loc);
|
||||
SmallVector<Value> multiDimClusterCTAId =
|
||||
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
|
||||
|
||||
Value llDst = adaptor.getDst();
|
||||
Value llIndex = adaptor.getIndex();
|
||||
Value src = op.getSrc();
|
||||
Value dst = op.getDst();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
|
||||
|
||||
// the offset of coord considering multicast slicing
|
||||
SmallVector<Value> mcastOffsetVals;
|
||||
// The index of slice is this CTAId is responsible for
|
||||
SmallVector<Value> multiDimSliceIdx(rank);
|
||||
for (auto i = 0; i < rank; ++i)
|
||||
multiDimSliceIdx[i] =
|
||||
udiv(multiDimClusterCTAId[i], i32_val(CTASplitNum[i]));
|
||||
Value sliceIdx =
|
||||
linearize(rewriter, loc, multiDimSliceIdx, numMcast, CTAOrder);
|
||||
|
||||
Value sliceCoord;
|
||||
for (auto i = 0; i < rank; ++i) {
|
||||
if (inOrder[i] == rank - 1) {
|
||||
// TODO[goostavz]: Cases when the size of the highest order is smaller
|
||||
// than numMcasts is not implemented.
|
||||
sliceCoord = mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast));
|
||||
mcastOffsetVals.emplace_back(
|
||||
mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast)));
|
||||
} else {
|
||||
mcastOffsetVals.emplace_back(i32_val(0));
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t elemsPerSlice = std::accumulate(
|
||||
shapePerCTA.begin(), shapePerCTA.end(), 1, std::multiplies{});
|
||||
Value dstOffsetCommon = mul(llIndex, i32_val(elemsPerSlice));
|
||||
// [benzh] sliceCoord should be higher dimension's multiplier accumulate.
|
||||
// currently only support rank == 2.
|
||||
dstOffsetCommon =
|
||||
add(dstOffsetCommon, mul(sliceCoord, i32_val(boxDims[0])));
|
||||
auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
|
||||
|
||||
Value tmaDesc =
|
||||
llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
|
||||
// TODO: sink this logic into Triton::NVGPU dialect and support more
|
||||
// cache-policy modes
|
||||
Value l2Desc = int_val(64, 0x1000000000000000ll);
|
||||
|
||||
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
|
||||
typeConverter->convertType(rewriter.getI8Type()), 3);
|
||||
|
||||
SmallVector<Value> coordCommon;
|
||||
auto llCoord = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getSrc(), rewriter, src.getType());
|
||||
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
auto dim = getDimOfOrder(inOrder, i);
|
||||
Value coordDim = bitcast(llCoord[dim], i32_ty);
|
||||
if (CTASplitNum[dim] != 1) {
|
||||
// Add offset for each CTA
|
||||
// boxDims[i] * (multiDimClusterCTAId[i] % CTASplitNum[i]);
|
||||
auto CTAOffset =
|
||||
mul(i32_val(shapePerCTA[dim]),
|
||||
urem(multiDimClusterCTAId[dim], i32_val(CTASplitNum[dim])));
|
||||
coordDim = add(coordDim, CTAOffset);
|
||||
}
|
||||
|
||||
if (i == rank - 1)
|
||||
// Add offset in case of multicast slicing
|
||||
coordCommon.push_back(add(coordDim, mcastOffsetVals[dim]));
|
||||
else
|
||||
coordCommon.push_back(coordDim);
|
||||
}
|
||||
|
||||
auto threadId = getThreadId(rewriter, loc);
|
||||
Value pred = icmp_eq(threadId, i32_val(0));
|
||||
|
||||
auto mask = adaptor.getMask();
|
||||
if (mask) {
|
||||
// TODO(thomas): What is the right implementation for this case?
|
||||
assert(mask.getType().isInteger(1) &&
|
||||
"need to implement cases with tensor mask");
|
||||
pred = rewriter.create<arith::AndIOp>(loc, pred, mask);
|
||||
}
|
||||
|
||||
Value mcastMask = getMCastMask(sharedLayout, rewriter, loc, clusterCTAId);
|
||||
|
||||
for (size_t i = 0; i < numBoxes; ++i) {
|
||||
Value dstOffset =
|
||||
add(dstOffsetCommon, i32_val(i * elemsPerBox * accNumMcast));
|
||||
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
|
||||
SmallVector<Value> coord = coordCommon;
|
||||
coord[0] = add(coordCommon[0], i32_val(i * boxDims[0]));
|
||||
rewriter.create<triton::nvgpu::TMALoadTiledOp>(
|
||||
loc, bitcast(dstPtrBase, ptrI8SharedTy), adaptor.getMbar(), tmaDesc,
|
||||
l2Desc, pred, coord, mcastMask);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, llDst);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
Value getMCastMask(const SharedEncodingAttr &sharedLayout,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value clusterCTAId) const {
|
||||
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
|
||||
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
|
||||
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
|
||||
|
||||
// Short path when no multicast is needed
|
||||
if (CTAsPerCGA == CTASplitNum)
|
||||
return nullptr;
|
||||
|
||||
// Short path when bcastMask is a constant
|
||||
bool isConstMcastMask = true;
|
||||
for (unsigned s : CTASplitNum) {
|
||||
if (s > 1) {
|
||||
isConstMcastMask = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (isConstMcastMask) {
|
||||
unsigned numCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(),
|
||||
1, std::multiplies{});
|
||||
return int_val(/*width*/ 16, (1u << numCTAs) - 1);
|
||||
}
|
||||
|
||||
SmallVector<Value> multiDimCTAId =
|
||||
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
|
||||
auto rank = CTAOrder.size();
|
||||
SmallVector<SmallVector<Value>> multiDimMask(rank);
|
||||
unsigned accNumMcast = 1;
|
||||
SmallVector<unsigned> numMcast(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
// For the ith dimension, CTAsPerCGA[i]/CTASplitNum[i] vals is to be
|
||||
// broadcasted, which for this CTAId is:
|
||||
// multiDimCTAId[i] % CTASplitNum[i] + (0 ..
|
||||
// (CTAsPerCGA[i]/CTASplitNum[i] - 1)) * CTASplitNum[i]
|
||||
// TODO: will there be cases if CTAsPerCGA[i]/CTASplitNum[i] < 1?
|
||||
Value rem = urem(multiDimCTAId[i], i32_val(CTASplitNum[i]));
|
||||
numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
|
||||
accNumMcast *= numMcast[i];
|
||||
for (unsigned j = 0; j < numMcast[i]; ++j) {
|
||||
if (j == 0) {
|
||||
multiDimMask[i].push_back(rem);
|
||||
} else {
|
||||
multiDimMask[i].push_back(add(rem, i32_val(j * CTASplitNum[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value bcastMask = int_val(/*width*/ 16, 0);
|
||||
Value _1_i16 = int_val(/*width*/ 16, 1);
|
||||
for (unsigned i = 0; i < accNumMcast; ++i) {
|
||||
SmallVector<unsigned> multiDimIdx =
|
||||
getMultiDimIndex<unsigned>(i, numMcast, CTAOrder);
|
||||
SmallVector<Value> multiDimMaskedCTAId(rank);
|
||||
for (unsigned dim = 0; dim < rank; ++dim) {
|
||||
multiDimMaskedCTAId[dim] = multiDimMask[dim][multiDimIdx[dim]];
|
||||
}
|
||||
Value bcastCTAId =
|
||||
linearize(rewriter, loc, multiDimMaskedCTAId, CTAsPerCGA, CTAOrder);
|
||||
// bcastMask |= 1u << bcastCTAId;
|
||||
bcastMask = or_(bcastMask, shl(_1_i16, trunc(i16_ty, bcastCTAId)));
|
||||
}
|
||||
|
||||
return bcastMask;
|
||||
}
|
||||
|
||||
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
|
||||
if (ty.isF16()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if (ty.isF32()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getArgIdx(Value v) const {
|
||||
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
} else if (v.getParentBlock()->isEntryBlock() && v.isa<BlockArgument>()) {
|
||||
// in entryblock and is BlockArgument; Because argument of func are
|
||||
// arugments of entryblock bb0 in MLIR
|
||||
return v.cast<BlockArgument>().getArgNumber();
|
||||
} else if (v.getParentBlock()->isEntryBlock() &&
|
||||
(!v.isa<BlockArgument>())) {
|
||||
// in entryblock but not BlockArgument
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
} else if (!v.getParentBlock()->isEntryBlock()) {
|
||||
// in non-entryblock
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
} else {
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
int getNumTMADescs(LLVM::LLVMFuncOp func) const {
|
||||
if (!func->hasAttr(kAttrNumTMALoadDescsName)) {
|
||||
llvm::report_fatal_error("TritonGPU module should contain a "
|
||||
"triton_gpu.num-tma-load attribute");
|
||||
return -1;
|
||||
}
|
||||
if (!func->hasAttr(kAttrNumTMAStoreDescsName)) {
|
||||
llvm::report_fatal_error("TritonGPU module should contain a "
|
||||
"triton_gpu.num-tma-store attribute");
|
||||
return -1;
|
||||
}
|
||||
return func->getAttr(kAttrNumTMAStoreDescsName)
|
||||
.cast<IntegerAttr>()
|
||||
.getInt() +
|
||||
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
const TensorPtrMapT *tensorPtrMap;
|
||||
};
|
||||
|
||||
void populateLoadStoreOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit) {
|
||||
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
|
||||
@@ -869,4 +1497,8 @@ void populateLoadStoreOpToLLVMPatterns(
|
||||
indexCacheInfo, benefit);
|
||||
patterns.add<InsertSliceAsyncOpConversion>(
|
||||
typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
|
||||
patterns.add<InsertSliceAsyncV2OpConversion>(
|
||||
typeConverter, allocation, tmaMetadata, tensorPtrMap, benefit);
|
||||
patterns.add<StoreAsyncOpConversion>(typeConverter, allocation, tmaMetadata,
|
||||
tensorPtrMap, benefit);
|
||||
}
|
||||
|
||||
@@ -8,8 +8,10 @@ using namespace mlir::triton;
|
||||
|
||||
void populateLoadStoreOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "ReduceOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
@@ -14,8 +16,13 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
struct ReduceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
|
||||
public:
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
ReduceOpConversion(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
int computeCapability, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp>(
|
||||
typeConverter, allocation, indexCacheInfo, benefit),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
@@ -26,14 +33,12 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
int computeCapability;
|
||||
|
||||
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
|
||||
llvm::SmallVectorImpl<Value> &acc, ValueRange cur,
|
||||
bool isFirst) const {
|
||||
SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
|
||||
if (isFirst) {
|
||||
acc.resize(cur.size());
|
||||
for (unsigned i = 0; i < cur.size(); ++i) {
|
||||
acc[i] = cur[i];
|
||||
}
|
||||
acc = SmallVector<Value>(cur.begin(), cur.end());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -114,7 +119,7 @@ private:
|
||||
// writeIdx[originalAxis] = index[originalAxis] / axisSizePerThread
|
||||
writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (!mmaLayout.isAmpere()) {
|
||||
if (!mmaLayout.isAmpere() && !mmaLayout.isHopper()) {
|
||||
llvm::report_fatal_error("Unsupported layout");
|
||||
}
|
||||
if (originalAxis == 0) {
|
||||
@@ -157,7 +162,6 @@ private:
|
||||
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
}
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
|
||||
|
||||
auto smemShape = helper.getScratchConfigBasic();
|
||||
unsigned elems = product<unsigned>(smemShape);
|
||||
@@ -171,33 +175,10 @@ private:
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
|
||||
// Emits indices of the original tensor that each thread
|
||||
// would own
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
|
||||
// Emits offsets (the offset from the base index)
|
||||
// of the original tensor that each thread would own
|
||||
// NOTE: Assumes offsets don't actually depend on type
|
||||
SmallVector<SmallVector<unsigned>> offset =
|
||||
emitOffsetForLayout(srcLayout, srcTys[0]);
|
||||
|
||||
// Keep track of accumulations and their indices
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
|
||||
|
||||
Region *combineOp = &op.getCombineOp();
|
||||
|
||||
// reduce within threads
|
||||
for (unsigned i = 0; i < srcElems; ++i) {
|
||||
SmallVector<unsigned> key = offset[i];
|
||||
key[axis] = 0;
|
||||
bool isFirst = accs.find(key) == accs.end();
|
||||
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
|
||||
if (isFirst)
|
||||
indices[key] = srcIndices[i];
|
||||
}
|
||||
reduceWithinThreads(helper, srcValues, accs, indices, rewriter);
|
||||
|
||||
// cached int32 constants
|
||||
std::map<int, Value> ints;
|
||||
@@ -249,15 +230,17 @@ private:
|
||||
readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset);
|
||||
}
|
||||
|
||||
barrier();
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// Combine accumulator value from another thread
|
||||
SmallVector<Value> cur(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
cur[i] = load(readPtrs[i]);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, cur, false);
|
||||
accumulate(rewriter, op.getCombineOp(), acc, cur, false);
|
||||
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
barrier();
|
||||
// Publish our new accumulator value to shared memory
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
store(acc[i], writePtrs[i]);
|
||||
@@ -265,7 +248,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// set output values
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
@@ -302,78 +285,186 @@ private:
|
||||
return success();
|
||||
}
|
||||
|
||||
// Use warp shuffle for reduction within warps and shared memory for data
|
||||
// exchange across warps
|
||||
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ReduceOpHelper helper(op);
|
||||
Location loc = op->getLoc();
|
||||
unsigned axis = adaptor.getAxis();
|
||||
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
if (!helper.isSupportedLayout()) {
|
||||
assert(false && "Unexpected srcLayout in ReduceOpConversion");
|
||||
void sync(ConversionPatternRewriter &rewriter, Location loc,
|
||||
triton::ReduceOp op) const {
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
|
||||
// attr.
|
||||
if (op->hasAttr("async_agent")) {
|
||||
barSync(rewriter, op, getAgentIds(op).front(), 128);
|
||||
} else {
|
||||
barrier();
|
||||
}
|
||||
auto srcOrd = triton::gpu::getOrder(srcLayout);
|
||||
auto srcShape = helper.getSrcShape();
|
||||
}
|
||||
|
||||
SmallVector<Type> elemPtrTys(srcTys.size());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto ty = srcTys[i].getElementType();
|
||||
auto llvmElemTy = getTypeConverter()->convertType(ty);
|
||||
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
// Check if the reduction can use a redux op and return the kind.
|
||||
std::optional<NVVM::ReduxKind> matchReduxKind(triton::ReduceOp op) const {
|
||||
if (computeCapability < 80)
|
||||
return std::nullopt;
|
||||
if (op.getNumOperands() != 1 || op.getNumResults() != 1)
|
||||
return std::nullopt;
|
||||
Block *block = &(*op.getCombineOp().begin());
|
||||
Operation *yield = block->getTerminator();
|
||||
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
|
||||
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
|
||||
reduceOp->getNumResults() != 1 ||
|
||||
!reduceOp->getResultTypes()[0].isInteger(32))
|
||||
return std::nullopt;
|
||||
if (reduceOp->getOperand(0) != block->getArgument(0) ||
|
||||
reduceOp->getOperand(1) != block->getArgument(1))
|
||||
return std::nullopt;
|
||||
if (isa<arith::AddIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::ADD;
|
||||
if (isa<arith::AndIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::AND;
|
||||
if (isa<arith::OrIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::OR;
|
||||
if (isa<arith::XOrIOp>(reduceOp))
|
||||
return NVVM::ReduxKind::XOR;
|
||||
if (auto externalCall =
|
||||
dyn_cast<triton::PureExternElementwiseOp>(reduceOp)) {
|
||||
if (externalCall.getSymbol() == "__nv_min")
|
||||
return NVVM::ReduxKind::MIN;
|
||||
if (externalCall.getSymbol() == "__nv_umin")
|
||||
return NVVM::ReduxKind::UMIN;
|
||||
if (externalCall.getSymbol() == "__nv_max")
|
||||
return NVVM::ReduxKind::MAX;
|
||||
if (externalCall.getSymbol() == "__nv_umax")
|
||||
return NVVM::ReduxKind::UMAX;
|
||||
}
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
|
||||
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
|
||||
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
bool isWarpSync = helper.isWarpSynchronous();
|
||||
|
||||
if (!isWarpSync) {
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] =
|
||||
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Reduce along op axis for elements that are in the same thread. The
|
||||
// accumulated value is stored in accs.
|
||||
void reduceWithinThreads(
|
||||
ReduceOpHelper &helper, SmallVector<SmallVector<Value>> &srcValues,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
RankedTensorType operandType = op.getInputTypes()[0];
|
||||
// Assumes offsets don't actually depend on type
|
||||
SmallVector<SmallVector<unsigned>> offset =
|
||||
emitOffsetForLayout(srcLayout, srcTys[0]);
|
||||
|
||||
emitOffsetForLayout(helper.getSrcLayout(), operandType);
|
||||
unsigned srcElems = getTotalElemsPerThread(operandType);
|
||||
auto *combineOp = &op.getCombineOp();
|
||||
|
||||
auto srcIndices =
|
||||
emitIndices(op.getLoc(), rewriter, helper.getSrcLayout(), operandType);
|
||||
// reduce within threads
|
||||
for (unsigned i = 0; i < srcElems; ++i) {
|
||||
SmallVector<unsigned> key = offset[i];
|
||||
key[axis] = 0;
|
||||
key[op.getAxis()] = 0;
|
||||
bool isFirst = accs.find(key) == accs.end();
|
||||
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
|
||||
if (isFirst)
|
||||
indices[key] = srcIndices[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Apply warp reduction across the given number of contiguous lanes using op
|
||||
// region and the accumulator values as source.
|
||||
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
|
||||
SmallVector<Value> &acc, triton::ReduceOp op,
|
||||
unsigned numLaneToReduce) const {
|
||||
if (auto kind = matchReduxKind(op)) {
|
||||
// Based on benchmarking on A100 redux op gives a speed up only when doing
|
||||
// a single reduction (not partioned) and when the mask is static.
|
||||
// Therefore we currently only enable it to reduce across all the lanes.
|
||||
if (numLaneToReduce == 32) {
|
||||
assert(acc.size() == 1);
|
||||
Value mask = i32_val(0xFFFFFFFF);
|
||||
// Even though we currently don't use redux for partitioned reduction
|
||||
// the code below supports it in case we want to tweak the heuristic.
|
||||
if (numLaneToReduce < 32) {
|
||||
// For partitioned reduction we need to caluclate the mask so that
|
||||
// each group of numLaneToReduce threads has the correct mask.
|
||||
unsigned bitmask = (1 << numLaneToReduce) - 1;
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value laneId = urem(threadId, i32_val(32));
|
||||
mask = shl(i32_val(bitmask),
|
||||
and_(laneId, i32_val(~(numLaneToReduce - 1))));
|
||||
}
|
||||
acc[0] = rewriter.create<NVVM::ReduxOp>(loc, acc[0].getType(), acc[0],
|
||||
*kind, mask);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
|
||||
SmallVector<Value> shfl(acc.size());
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], N);
|
||||
}
|
||||
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce across threads within each warp.
|
||||
void
|
||||
reduceWithinWarps(ReduceOpHelper &helper,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
SmallVector<Value> &acc = accs[key];
|
||||
warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack the accumualtor values and replace the reduce op with the result.
|
||||
void packResults(ReduceOpHelper &helper,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
Location loc = op.getLoc();
|
||||
unsigned axis = op.getAxis();
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
SmallVector<SmallVector<unsigned>> resultOffset =
|
||||
emitOffsetForLayout(resultLayout, resultTy);
|
||||
SmallVector<Value> resultVals;
|
||||
for (int j = 0; j < resultElems; j++) {
|
||||
auto key = resultOffset[j];
|
||||
key.insert(key.begin() + axis, 0);
|
||||
resultVals.push_back(accs[key][i]);
|
||||
}
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else
|
||||
results[i] = accs.begin()->second[i];
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
}
|
||||
|
||||
// Return the type of the shared memory pointer for operand i.
|
||||
Type getElementPtrType(triton::ReduceOp op, int i) const {
|
||||
auto ty = op.getInputTypes()[i].getElementType();
|
||||
auto llvmElemTy = getTypeConverter()->convertType(ty);
|
||||
return LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
}
|
||||
|
||||
void storeWarpReduceToSharedMemory(
|
||||
ReduceOpHelper &helper,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
|
||||
SmallVector<Value> &smemBases,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
Location loc = op.getLoc();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
auto srcShape = helper.getSrcShape();
|
||||
unsigned axis = op.getAxis();
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
|
||||
auto threadsPerWarp =
|
||||
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
|
||||
@@ -391,67 +482,38 @@ private:
|
||||
Value zero = i32_val(0);
|
||||
Value laneZero = icmp_eq(laneIdAxis, zero);
|
||||
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
SmallVector<Value> acc = it.second;
|
||||
|
||||
// Reduce within warps
|
||||
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
|
||||
SmallVector<Value> shfl(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], N);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
if (isWarpSync) {
|
||||
finalAccs[key] = acc;
|
||||
continue;
|
||||
}
|
||||
SmallVector<Value> &acc = it.second;
|
||||
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = warpIdAxis;
|
||||
Value writeOffset =
|
||||
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset);
|
||||
auto elemPtrTy = getElementPtrType(op, i);
|
||||
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc[i], laneZero);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isWarpSync) {
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
SmallVector<SmallVector<unsigned>> resultOffset =
|
||||
emitOffsetForLayout(resultLayout, resultTy);
|
||||
SmallVector<Value> resultVals;
|
||||
for (int j = 0; j < resultElems; j++) {
|
||||
auto key = resultOffset[j];
|
||||
key.insert(key.begin() + axis, 0);
|
||||
resultVals.push_back(finalAccs[key][i]);
|
||||
}
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else
|
||||
results[i] = finalAccs.begin()->second[i];
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
// Load the reduction of each warp and accumulate them to a final value and
|
||||
// store back to shared memory.
|
||||
void accumulatePartialReductions(ReduceOpHelper &helper,
|
||||
SmallVector<Value> &smemBases,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
barrier();
|
||||
|
||||
// The second round of shuffle reduction
|
||||
// now the problem size: sizeInterWarps, s1, s2, .. , sn
|
||||
// where sizeInterWarps is 2^m
|
||||
//
|
||||
// Each thread needs to process:
|
||||
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value zero = i32_val(0);
|
||||
|
||||
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
|
||||
unsigned numThreads =
|
||||
@@ -464,23 +526,18 @@ private:
|
||||
// i32_val(sizeInerWarps))
|
||||
SmallVector<Value> acc(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
|
||||
auto elemPtrTy = getElementPtrType(op, i);
|
||||
Value readPtr = gep(elemPtrTy, smemBases[i], readOffset);
|
||||
acc[i] = load(readPtr);
|
||||
}
|
||||
|
||||
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
||||
SmallVector<Value> shfl(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], N);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
warpReduce(rewriter, loc, acc, op, sizeInterWarps);
|
||||
|
||||
// only the first thread in each sizeInterWarps is writing
|
||||
Value writeOffset = readOffset;
|
||||
SmallVector<Value> writePtrs(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
|
||||
auto elemPtrTy = getElementPtrType(op, i);
|
||||
writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset);
|
||||
}
|
||||
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
|
||||
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
|
||||
@@ -496,10 +553,17 @@ private:
|
||||
readOffset = add(readOffset, i32_val(numThreads));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
// set output values
|
||||
// Load the final reduction from shared memory and replace the reduce result
|
||||
// with it.
|
||||
void loadReductionAndPackResult(ReduceOpHelper &helper,
|
||||
SmallVector<Value> &smemBases,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
Location loc = op.getLoc();
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
auto order = getOrder(helper.getSrcLayout());
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
@@ -513,10 +577,11 @@ private:
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (size_t j = 0; j < resultElems; ++j) {
|
||||
SmallVector<Value> readIdx = resultIndices[j];
|
||||
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
||||
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, readIdx, smemShapes[0], order);
|
||||
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
|
||||
Value readPtr =
|
||||
gep(getElementPtrType(op, i), smemBases[i], readOffset);
|
||||
resultVals[j] = load(readPtr);
|
||||
}
|
||||
|
||||
@@ -528,6 +593,65 @@ private:
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
}
|
||||
|
||||
// Use warp shuffle for reduction within warps and shared memory for data
|
||||
// exchange across warps
|
||||
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ReduceOpHelper helper(op);
|
||||
assert(helper.isSupportedLayout() &&
|
||||
"Unexpected srcLayout in ReduceOpConversion");
|
||||
Location loc = op->getLoc();
|
||||
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
|
||||
// First reduce all the values along axis within each thread.
|
||||
reduceWithinThreads(helper, srcValues, accs, indices, rewriter);
|
||||
|
||||
// Then reduce across threads within a warp.
|
||||
reduceWithinWarps(helper, accs, rewriter);
|
||||
|
||||
if (helper.isWarpSynchronous()) {
|
||||
// If all the values to be reduced are within the same warp there is
|
||||
// nothing left to do.
|
||||
packResults(helper, accs, rewriter);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Compute a shared memory base per operand.
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
smemBases[0] =
|
||||
bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()),
|
||||
getElementPtrType(op, 0));
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] = bitcast(gep(getElementPtrType(op, i - 1), smemBases[i - 1],
|
||||
i32_val(maxElems)),
|
||||
getElementPtrType(op, i));
|
||||
}
|
||||
storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter);
|
||||
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// The second round of shuffle reduction
|
||||
// now the problem size: sizeInterWarps, s1, s2, .. , sn
|
||||
// where sizeInterWarps is 2^m
|
||||
//
|
||||
// Each thread needs to process:
|
||||
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
|
||||
accumulatePartialReductions(helper, smemBases, rewriter);
|
||||
|
||||
// We could avoid this barrier in some of the layouts, however this is not
|
||||
// the general case.
|
||||
// TODO: optimize the barrier in case the layouts are accepted.
|
||||
sync(rewriter, loc, op);
|
||||
|
||||
// set output values
|
||||
loadReductionAndPackResult(helper, smemBases, rewriter);
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -535,9 +659,10 @@ private:
|
||||
|
||||
void populateReduceOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
int computeCapability, PatternBenefit benefit) {
|
||||
patterns.add<ReduceOpConversion>(typeConverter, allocation, indexCacheInfo,
|
||||
benefit);
|
||||
computeCapability, benefit);
|
||||
}
|
||||
|
||||
@@ -8,8 +8,9 @@ using namespace mlir::triton;
|
||||
|
||||
void populateReduceOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
int computeCapability, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
|
||||
43
lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp
Normal file
43
lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp
Normal file
@@ -0,0 +1,43 @@
|
||||
#include "RegReallocOpToLLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
struct RegAllocOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegAllocOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::RegAllocOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::RegAllocOp>(
|
||||
op, adaptor.getRegCount());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RegDeallocOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegDeallocOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::RegDeallocOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::RegDeallocOp>(
|
||||
op, adaptor.getRegCount());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateRegReallocOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
const ModuleAllocation &allocation, PatternBenefit benefit) {
|
||||
patterns.add<RegAllocOpConversion>(typeConverter, benefit);
|
||||
patterns.add<RegDeallocOpConversion>(typeConverter, benefit);
|
||||
return;
|
||||
}
|
||||
14
lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.h
Normal file
14
lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REGREALLOC_OP_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REGREALLOC_OP_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateRegReallocOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
const ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
@@ -316,6 +316,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
|
||||
|
||||
void populateScanOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
|
||||
@@ -8,6 +8,7 @@ using namespace mlir::triton;
|
||||
|
||||
void populateScanOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
104
lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.cpp
Normal file
104
lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.cpp
Normal file
@@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "TensorPtrOpsToLLVM.h"
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
struct MakeTensorPtrOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeTensorPtrOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::MakeTensorPtrOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
// struct { offset0, offset1, shape0, shape1, stride0,
|
||||
// stride1, base_ptr};
|
||||
auto offsets = adaptor.getOffsets();
|
||||
auto shapes = adaptor.getShape();
|
||||
auto strides = adaptor.getStrides();
|
||||
auto base = adaptor.getBase();
|
||||
auto result = op.getResult();
|
||||
|
||||
SmallVector<Value> elems;
|
||||
for (auto offset : offsets)
|
||||
elems.push_back(offset);
|
||||
for (auto shape : shapes)
|
||||
elems.push_back(shape);
|
||||
for (auto stride : strides)
|
||||
elems.push_back(stride);
|
||||
|
||||
elems.push_back(base);
|
||||
|
||||
auto newValue = getTypeConverter()->packLLElements(
|
||||
op.getLoc(), elems, rewriter, result.getType());
|
||||
rewriter.replaceOp(op, newValue);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AdvanceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AdvanceOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::AdvanceOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// struct { offset0, offset1, shape0, shape1, stride0,
|
||||
// stride1, base_ptr};
|
||||
auto loc = op.getLoc();
|
||||
auto ptrType = op.getPtr().getType();
|
||||
auto tensorPtr = adaptor.getPtr();
|
||||
|
||||
auto offsets = adaptor.getOffsets();
|
||||
auto elems =
|
||||
getTypeConverter()->unpackLLElements(loc, tensorPtr, rewriter, ptrType);
|
||||
|
||||
SmallVector<Value, 2> newOffsets;
|
||||
|
||||
for (auto [offset, oldOffset] : llvm::zip_first(offsets, elems)) {
|
||||
newOffsets.push_back((add(offset, oldOffset)));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < newOffsets.size(); ++i) {
|
||||
elems[i] = newOffsets[i];
|
||||
}
|
||||
|
||||
auto newValue = getTypeConverter()->packLLElements(op.getLoc(), elems,
|
||||
rewriter, ptrType);
|
||||
rewriter.replaceOp(op, newValue);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTensorPtrOpsToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit) {
|
||||
patterns.add<MakeTensorPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AdvanceOpConversion>(typeConverter, benefit);
|
||||
return;
|
||||
}
|
||||
37
lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.h
Normal file
37
lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.h
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TENSOR_PTR_OPS_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TENSOR_PTR_OPS_H
|
||||
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateTensorPtrOpsToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation, PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
@@ -389,18 +389,23 @@ struct GetProgramIdOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
// "%clusterid".
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
|
||||
sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
Value programId = getSRegValue(rewriter, loc, sreg);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, programId);
|
||||
return success();
|
||||
}
|
||||
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
};
|
||||
|
||||
struct GetNumProgramsOpConversion
|
||||
@@ -411,19 +416,54 @@ struct GetNumProgramsOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetNumProgramsOp. If numCTAs = 1, then
|
||||
// GetNumProgramsOp is converted to "%nctaid", otherwise it is converted to
|
||||
// "%nclusterid".
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxis() < 3);
|
||||
std::string sreg = numCTAs == 1 ? "%nctaid." : "%nclusterid.";
|
||||
sreg.append(1, 'x' + op.getAxis()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
|
||||
Value numPrograms = getSRegValue(rewriter, loc, sreg);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, numPrograms);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
// TODO[goostavz]: GetThreadIdOp/GetClusterCTAIdOp is a temporary solution
|
||||
// before async dialect is done. These concepts should appear in ttgpu
|
||||
// level, and they are planned to be deprecated along with ttgpu.mbarrier_xxx
|
||||
// ops.
|
||||
struct GetThreadIdOpConversion : public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetThreadIdOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetThreadIdOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::GetThreadIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, getThreadId(rewriter, op->getLoc()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct GetClusterCTAIdOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetClusterCTAIdOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::nvidia_gpu::GetClusterCTAIdOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::nvidia_gpu::GetClusterCTAIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, getClusterCTAId(rewriter, op->getLoc()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AddPtrOpConversion
|
||||
@@ -479,7 +519,8 @@ struct AllocTensorOpConversion
|
||||
getTypeConverter()->convertType(resultTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
|
||||
auto sharedLayout = resultTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto order = sharedLayout.getOrder();
|
||||
// Workaround for 3D tensors
|
||||
// TODO: we need to modify the pipeline pass to give a proper shared
|
||||
// encoding to 3D tensors
|
||||
@@ -489,8 +530,9 @@ struct AllocTensorOpConversion
|
||||
else
|
||||
newOrder = SmallVector<unsigned>(order.begin(), order.end());
|
||||
|
||||
auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder,
|
||||
loc, rewriter);
|
||||
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
|
||||
auto smemObj =
|
||||
SharedMemoryObject(smemBase, shapePerCTA, newOrder, loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
@@ -593,6 +635,49 @@ struct AsyncCommitGroupOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct AsyncBulkWaitOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncBulkWaitOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu::AsyncBulkWaitOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu::AsyncBulkWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &asyncBulkWaitOp = *ptxBuilder.create<>("cp.async.bulk.wait_group");
|
||||
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
|
||||
asyncBulkWaitOp(ptxBuilder.newConstantOperand(num));
|
||||
|
||||
auto ctx = op.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto voidTy = void_ty(ctx);
|
||||
ptxBuilder.launch(rewriter, loc, voidTy);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AsyncBulkCommitGroupOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu::AsyncBulkCommitGroupOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu::AsyncBulkCommitGroupOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu::AsyncBulkCommitGroupOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
ptxBuilder.create<>("cp.async.bulk.commit_group")->operator()();
|
||||
ptxBuilder.launch(rewriter, op.getLoc(), void_ty(op.getContext()));
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
@@ -618,6 +703,7 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &moduleAllocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
@@ -626,12 +712,15 @@ void populateTritonGPUToLLVMPatterns(
|
||||
benefit);
|
||||
patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncBulkCommitGroupOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncBulkWaitOpConversion>(typeConverter, benefit);
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
|
||||
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
|
||||
benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetClusterCTAIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
|
||||
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintOpConversion>(typeConverter, benefit);
|
||||
|
||||
@@ -8,6 +8,7 @@ using namespace mlir::triton;
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
@@ -11,16 +11,28 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
#include <set>
|
||||
|
||||
#define DEBUG_TYPE "ttgpu_to_llvm"
|
||||
|
||||
constexpr ::llvm::StringLiteral kAttrNumTMALoadDescsName =
|
||||
"triton_gpu.num-tma-load";
|
||||
constexpr ::llvm::StringLiteral kAttrNumTMAStoreDescsName =
|
||||
"triton_gpu.num-tma-store";
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::SharedMemoryObject;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::CTALayoutAttr;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
using ::mlir::triton::gpu::TMAMetadataTy;
|
||||
|
||||
typedef DenseMap<Operation *, triton::MakeTensorPtrOp> TensorPtrMapT;
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
@@ -141,36 +153,39 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
|
||||
struct IndexCacheKeyT {
|
||||
Attribute layout;
|
||||
RankedTensorType type;
|
||||
bool withCTAOffset;
|
||||
};
|
||||
|
||||
struct CacheKeyDenseMapInfo {
|
||||
static IndexCacheKeyT getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return std::make_pair(
|
||||
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||
RankedTensorType{});
|
||||
return {mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||
RankedTensorType{}, true};
|
||||
}
|
||||
static IndexCacheKeyT getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
auto tombstone = llvm::DenseMapInfo<RankedTensorType>::getTombstoneKey();
|
||||
return std::make_pair(
|
||||
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||
tombstone);
|
||||
return {mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||
tombstone, true};
|
||||
}
|
||||
static unsigned getHashValue(IndexCacheKeyT key) {
|
||||
auto shape = key.second.getShape();
|
||||
return llvm::hash_combine(mlir::hash_value(key.first),
|
||||
mlir::hash_value(key.second));
|
||||
return llvm::hash_combine(mlir::hash_value(key.layout),
|
||||
mlir::hash_value(key.type),
|
||||
llvm::hash_value(key.withCTAOffset));
|
||||
}
|
||||
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
|
||||
return LHS == RHS;
|
||||
return LHS.layout == RHS.layout && LHS.type == RHS.type &&
|
||||
LHS.withCTAOffset == RHS.withCTAOffset;
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertTritonGPUOpToLLVMPatternBase {
|
||||
public:
|
||||
// Two levels of value cache in emitting indices calculation:
|
||||
// Key: pair<layout, shape>
|
||||
// Key: {layout, shape, withCTAOffset}
|
||||
struct IndexCacheInfo {
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
*baseIndexCache;
|
||||
@@ -198,6 +213,12 @@ public:
|
||||
: converter(&typeConverter), allocation(&allocation),
|
||||
indexCacheInfo(indexCacheInfo) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
TMAMetadataTy *tmaMetadata)
|
||||
: converter(&typeConverter), allocation(&allocation),
|
||||
tmaMetadata(tmaMetadata) {}
|
||||
|
||||
TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; }
|
||||
|
||||
static Value
|
||||
@@ -223,6 +244,26 @@ public:
|
||||
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
|
||||
}
|
||||
|
||||
static Value getSRegValue(OpBuilder &b, Location loc,
|
||||
const std::string &sRegStr) {
|
||||
PTXBuilder builder;
|
||||
auto &mov = builder.create("mov")->o("u32");
|
||||
auto *destOpr = builder.newOperand("=r");
|
||||
auto *sRegOpr = builder.newConstantOperand(sRegStr);
|
||||
mov(destOpr, sRegOpr);
|
||||
Value val = builder.launch(b, loc, b.getIntegerType(32), false);
|
||||
|
||||
auto cast = b.create<UnrealizedConversionCastOp>(
|
||||
loc, TypeRange{b.getIntegerType(32)}, ValueRange{val});
|
||||
return cast.getResult(0);
|
||||
}
|
||||
|
||||
Value getClusterCTAId(ConversionPatternRewriter &rewriter,
|
||||
Location loc) const {
|
||||
return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(
|
||||
loc, rewriter.getI32Type());
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Shared memory utilities
|
||||
// -----------------------------------------------------------------------
|
||||
@@ -259,7 +300,7 @@ public:
|
||||
// for all indices (row, col) of `srcEncoding` such that idx % inVec = 0,
|
||||
// the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] +
|
||||
// colOff) where :
|
||||
// compute phase = (row // perPhase) % maxPhase
|
||||
// phase = (row // perPhase) % maxPhase
|
||||
// rowOff = row
|
||||
// colOff = colOffSwizzled + colOffOrdered
|
||||
// colOffSwizzled = ((col // outVec) ^ phase) * outVec
|
||||
@@ -280,60 +321,89 @@ public:
|
||||
// then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
|
||||
// This means that we can use some immediate offsets for shared memory
|
||||
// operations.
|
||||
auto dstPtrTy = ptr_ty(resElemTy, 3);
|
||||
auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3);
|
||||
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
|
||||
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
|
||||
|
||||
auto srcEncoding = srcTy.getEncoding();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy);
|
||||
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
// swizzling params as described in TritonGPUAttrDefs.td
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
// order
|
||||
// Order
|
||||
auto inOrder = triton::gpu::getOrder(srcEncoding);
|
||||
auto outOrder = triton::gpu::getOrder(resSharedLayout);
|
||||
// tensor indices held by the current thread, as LLVM values
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy);
|
||||
// return values
|
||||
// Tensor indices held by the current thread, as LLVM values
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
|
||||
// Swizzling with leading offsets (e.g. Hopper GMMA)
|
||||
unsigned swizzlingByteWidth = 0;
|
||||
if (resSharedLayout.getHasLeadingOffset()) {
|
||||
if (perPhase == 4 && maxPhase == 2)
|
||||
swizzlingByteWidth = 32;
|
||||
else if (perPhase == 2 && maxPhase == 4)
|
||||
swizzlingByteWidth = 64;
|
||||
else if (perPhase == 1 && maxPhase == 8)
|
||||
swizzlingByteWidth = 128;
|
||||
else
|
||||
llvm::report_fatal_error("Unsupported shared layout.");
|
||||
}
|
||||
unsigned numElemsPerSwizzlingRow =
|
||||
swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth();
|
||||
Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
|
||||
unsigned leadingDimOffset =
|
||||
numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
|
||||
Value leadingDimOffsetVal = i32_val(leadingDimOffset);
|
||||
// Return values
|
||||
DenseMap<unsigned, Value> ret;
|
||||
// cache for non-immediate offsets
|
||||
DenseMap<unsigned, Value> cacheCol, cacheRow;
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
|
||||
// extract multi dimensional index for current element
|
||||
Value offset = i32_val(0);
|
||||
// Extract multi dimensional index for current element
|
||||
auto idx = srcIndices[elemIdx];
|
||||
Value idxCol = idx[outOrder[0]]; // contiguous dimension
|
||||
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
|
||||
Value strideCol = srcStrides[outOrder[0]];
|
||||
Value strideRow = srcStrides[outOrder[1]];
|
||||
// extract dynamic/static offset for immediate offsetting
|
||||
unsigned immedateOffCol = 0;
|
||||
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp()))
|
||||
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
|
||||
add.getRhs().getDefiningOp())) {
|
||||
unsigned cst =
|
||||
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
unsigned key = cst % (outVec * maxPhase);
|
||||
cacheCol.insert({key, idxCol});
|
||||
idxCol = cacheCol[key];
|
||||
immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase);
|
||||
}
|
||||
// extract dynamic/static offset for immediate offsetting
|
||||
unsigned immedateOffRow = 0;
|
||||
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxRow.getDefiningOp()))
|
||||
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
|
||||
add.getRhs().getDefiningOp())) {
|
||||
unsigned cst =
|
||||
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
unsigned key = cst % (perPhase * maxPhase);
|
||||
cacheRow.insert({key, idxRow});
|
||||
idxRow = cacheRow[key];
|
||||
immedateOffRow = cst / (perPhase * maxPhase) * (perPhase * maxPhase);
|
||||
}
|
||||
// compute phase = (row // perPhase) % maxPhase
|
||||
Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
|
||||
// extract dynamic/static offset for immediate offsetting
|
||||
unsigned immedateOffCol = 0;
|
||||
unsigned immedateOffRow = 0;
|
||||
if (leadingDimOffset) {
|
||||
// hopper
|
||||
offset =
|
||||
mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal);
|
||||
// Shrink by swizzling blocks
|
||||
idxCol = urem(idxCol, numElemsPerSwizzlingRowVal);
|
||||
strideRow = numElemsPerSwizzlingRowVal;
|
||||
} else {
|
||||
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp()))
|
||||
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
|
||||
add.getRhs().getDefiningOp())) {
|
||||
unsigned cst =
|
||||
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
unsigned key = cst % (outVec * maxPhase);
|
||||
cacheCol.insert({key, idxCol});
|
||||
idxCol = cacheCol[key];
|
||||
immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase);
|
||||
}
|
||||
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxRow.getDefiningOp()))
|
||||
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
|
||||
add.getRhs().getDefiningOp())) {
|
||||
unsigned cst =
|
||||
_cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
unsigned key = cst % (perPhase * maxPhase);
|
||||
cacheRow.insert({key, idxRow});
|
||||
idxRow = cacheRow[key];
|
||||
immedateOffRow =
|
||||
cst / (perPhase * maxPhase) * (perPhase * maxPhase);
|
||||
}
|
||||
}
|
||||
// row offset is simply row index
|
||||
Value rowOff = mul(idxRow, strideRow);
|
||||
// because swizzling happens at a granularity of outVec, we need to
|
||||
@@ -347,7 +417,7 @@ public:
|
||||
colOffOrdered = mul(colOffOrdered, i32_val(minVec));
|
||||
Value colOff = add(colOffSwizzled, colOffOrdered);
|
||||
// compute non-immediate offset
|
||||
Value offset = add(rowOff, mul(colOff, strideCol));
|
||||
offset = add(offset, add(rowOff, mul(colOff, strideCol)));
|
||||
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
|
||||
// compute immediate offset
|
||||
Value immedateOff =
|
||||
@@ -477,7 +547,7 @@ public:
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
|
||||
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
|
||||
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(tid, warpSize);
|
||||
Value warpId = udiv(tid, warpSize);
|
||||
@@ -487,7 +557,7 @@ public:
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
for (unsigned dim = 0; dim < rank; ++dim) {
|
||||
// if there is no data replication across threads on this dimension
|
||||
if (shape[dim] >= shapePerCTA[dim])
|
||||
if (shape[dim] >= shapePerCTATile[dim])
|
||||
continue;
|
||||
// Otherwise, we need to mask threads that will replicate data on this
|
||||
// dimension. Calculate the thread index on this dimension for the CTA
|
||||
@@ -535,13 +605,48 @@ public:
|
||||
// Get offsets / indices for any layout
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
SmallVector<Value> emitCTAOffsetForLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Attribute layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
unsigned rank = shape.size();
|
||||
SmallVector<unsigned> CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
|
||||
SmallVector<unsigned> CTASplitNum = triton::gpu::getCTASplitNum(layout);
|
||||
SmallVector<unsigned> CTAOrder = triton::gpu::getCTAOrder(layout);
|
||||
SmallVector<int64_t> shapePerCTA =
|
||||
triton::gpu::getShapePerCTA(CTASplitNum, shape);
|
||||
|
||||
// Delinearize clusterCTAId
|
||||
Value clusterCTAId = getClusterCTAId(rewriter, loc);
|
||||
SmallVector<Value> multiDimClusterCTAId =
|
||||
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
|
||||
|
||||
// CTA Wrapping
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
// This wrapping rule must be consistent with getShapePerCTA
|
||||
unsigned splitNum = std::min<unsigned>(shape[i], CTASplitNum[i]);
|
||||
multiDimClusterCTAId[i] =
|
||||
urem(multiDimClusterCTAId[i], i32_val(splitNum));
|
||||
}
|
||||
|
||||
SmallVector<Value> CTAOffset(rank);
|
||||
for (unsigned i = 0; i < rank; ++i)
|
||||
CTAOffset[i] = mul(multiDimClusterCTAId[i], i32_val(shapePerCTA[i]));
|
||||
|
||||
return CTAOffset;
|
||||
}
|
||||
|
||||
SmallVector<Value> emitBaseIndexForLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Attribute layout,
|
||||
RankedTensorType type) const {
|
||||
IndexCacheKeyT key = std::make_pair(layout, type);
|
||||
RankedTensorType type,
|
||||
bool withCTAOffset) const {
|
||||
auto shape = type.getShape();
|
||||
IndexCacheKeyT key{layout, type, withCTAOffset};
|
||||
auto cache = indexCacheInfo.baseIndexCache;
|
||||
auto insertPt = indexCacheInfo.indexInsertPoint;
|
||||
|
||||
SmallVector<Value> baseIndex;
|
||||
if (cache && cache->count(key) > 0) {
|
||||
return cache->lookup(key);
|
||||
} else {
|
||||
@@ -550,23 +655,34 @@ public:
|
||||
restoreInsertionPointIfSet(insertPt, rewriter);
|
||||
SmallVector<Value> result;
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
result =
|
||||
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, type);
|
||||
result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
|
||||
blockedLayout, type);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
|
||||
if (mmaLayout.isAmpere())
|
||||
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
|
||||
result = emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter,
|
||||
mmaLayout, type);
|
||||
if (mmaLayout.isAmpere() || mmaLayout.isHopper())
|
||||
result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter,
|
||||
mmaLayout, type);
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
auto parentShape = sliceLayout.paddedShape(type.getShape());
|
||||
RankedTensorType parentTy = RankedTensorType::get(
|
||||
parentShape, type.getElementType(), parentLayout);
|
||||
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy);
|
||||
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy,
|
||||
withCTAOffset);
|
||||
result.erase(result.begin() + sliceLayout.getDim());
|
||||
// CTAOffset has been added in emitBaseIndexForLayout of parentLayout
|
||||
return result;
|
||||
} else {
|
||||
llvm_unreachable("unsupported emitBaseIndexForLayout");
|
||||
}
|
||||
if (withCTAOffset) {
|
||||
auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape);
|
||||
assert(CTAOffset.size() == result.size() && "Rank mismatch");
|
||||
for (unsigned k = 0; k < result.size(); ++k)
|
||||
result[k] = add(result[k], CTAOffset[k]);
|
||||
}
|
||||
if (cache) {
|
||||
cache->insert(std::make_pair(key, result));
|
||||
*insertPt = rewriter.saveInsertionPoint();
|
||||
@@ -584,6 +700,8 @@ public:
|
||||
return emitOffsetForMmaLayoutV1(mmaLayout, type);
|
||||
if (mmaLayout.isAmpere())
|
||||
return emitOffsetForMmaLayoutV2(mmaLayout, type);
|
||||
if (mmaLayout.isHopper())
|
||||
return emitOffsetForMmaLayoutV3(mmaLayout, type);
|
||||
}
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
|
||||
return emitOffsetForSliceLayout(sliceLayout, type);
|
||||
@@ -593,11 +711,10 @@ public:
|
||||
// -----------------------------------------------------------------------
|
||||
// Emit indices
|
||||
// -----------------------------------------------------------------------
|
||||
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
||||
ConversionPatternRewriter &b,
|
||||
Attribute layout,
|
||||
RankedTensorType type) const {
|
||||
IndexCacheKeyT key(layout, type);
|
||||
SmallVector<SmallVector<Value>>
|
||||
emitIndices(Location loc, ConversionPatternRewriter &b, Attribute layout,
|
||||
RankedTensorType type, bool withCTAOffset = true) const {
|
||||
IndexCacheKeyT key{layout, type, withCTAOffset};
|
||||
auto cache = indexCacheInfo.indexCache;
|
||||
auto insertPt = indexCacheInfo.indexInsertPoint;
|
||||
if (cache && cache->count(key) > 0) {
|
||||
@@ -608,11 +725,14 @@ public:
|
||||
restoreInsertionPointIfSet(insertPt, b);
|
||||
SmallVector<SmallVector<Value>> result;
|
||||
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
result = emitIndicesForDistributedLayout(loc, b, blocked, type);
|
||||
result = emitIndicesForDistributedLayout(loc, b, blocked, type,
|
||||
withCTAOffset);
|
||||
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
result = emitIndicesForDistributedLayout(loc, b, mma, type);
|
||||
result =
|
||||
emitIndicesForDistributedLayout(loc, b, mma, type, withCTAOffset);
|
||||
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
result = emitIndicesForDistributedLayout(loc, b, slice, type);
|
||||
result =
|
||||
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
"emitIndices for layouts other than blocked & slice not "
|
||||
@@ -642,19 +762,20 @@ private:
|
||||
// Blocked layout indices
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Get an index-base for each dimension for a \param blocked_layout.
|
||||
SmallVector<Value> emitBaseIndexForBlockedLayout(
|
||||
// Get an index-base for each dimension for a \param blockedLayout.
|
||||
SmallVector<Value> emitBaseIndexWithinCTAForBlockedLayout(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blocked_layout, RankedTensorType type) const {
|
||||
const BlockedEncodingAttr &blockedLayout, RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
auto sizePerThread = blocked_layout.getSizePerThread();
|
||||
auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
|
||||
auto order = blocked_layout.getOrder();
|
||||
auto sizePerThread = blockedLayout.getSizePerThread();
|
||||
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
|
||||
auto order = blockedLayout.getOrder();
|
||||
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
|
||||
unsigned rank = shape.size();
|
||||
|
||||
// delinearize threadId to get the base index
|
||||
@@ -666,10 +787,10 @@ private:
|
||||
SmallVector<Value> multiDimBase(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// Wrap around multiDimWarpId/multiDimThreadId in case
|
||||
// shape[k] > shapePerCTA[k]
|
||||
// shapePerCTATile[k] > shapePerCTA[k]
|
||||
auto maxWarps =
|
||||
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
|
||||
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
|
||||
ceil<unsigned>(shapePerCTA[k], sizePerThread[k] * threadsPerWarp[k]);
|
||||
auto maxThreads = ceil<unsigned>(shapePerCTA[k], sizePerThread[k]);
|
||||
multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps));
|
||||
multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads));
|
||||
// multiDimBase[k] = (multiDimThreadId[k] +
|
||||
@@ -692,16 +813,17 @@ private:
|
||||
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
|
||||
auto order = blockedLayout.getOrder();
|
||||
auto shapePerCTATile = getShapePerCTATile(blockedLayout);
|
||||
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
|
||||
|
||||
unsigned rank = shape.size();
|
||||
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
|
||||
SmallVector<unsigned> tilesPerDim(rank);
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
|
||||
tilesPerDim[k] = ceil<unsigned>(shapePerCTA[k], shapePerCTATile[k]);
|
||||
|
||||
SmallVector<SmallVector<unsigned>> offset(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
|
||||
// 1 CTA tile in minimum if shapePerCTA[k] is less than shapePerCTATile[k]
|
||||
for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
|
||||
++blockOffset)
|
||||
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
|
||||
@@ -741,12 +863,10 @@ private:
|
||||
// Mma layout indices
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
RankedTensorType type) const {
|
||||
SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV1(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout, RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
static constexpr std::array<int, 3> fpw{{2, 2, 1}};
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, _] =
|
||||
@@ -879,9 +999,6 @@ private:
|
||||
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
unsigned lastAxis = order[order.size() - 1];
|
||||
multiDimWarpId[lastAxis] =
|
||||
urem(multiDimWarpId[lastAxis], i32_val(warpsPerCTA[lastAxis]));
|
||||
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
|
||||
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
|
||||
Value offWarp0 = mul(multiDimWarpId[0], i32_val(16));
|
||||
@@ -897,10 +1014,13 @@ private:
|
||||
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
|
||||
RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
|
||||
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
|
||||
for (unsigned i = 0; i < shapePerCTA[0];
|
||||
i += getShapePerCTATile(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shapePerCTA[1];
|
||||
j += getShapePerCTATile(mmaLayout)[1]) {
|
||||
ret.push_back({i, j});
|
||||
ret.push_back({i, j + 1});
|
||||
ret.push_back({i + 8, j});
|
||||
@@ -910,17 +1030,88 @@ private:
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV2V3(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout, RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
assert(_warpsPerCTA.size() == 2);
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
ArrayRef<unsigned int> instrShape = mmaLayout.getInstrShape();
|
||||
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
|
||||
i32_val(_warpsPerCTA[1])};
|
||||
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
|
||||
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
|
||||
uint32_t repM = (_warpsPerCTA[0] * instrShape[0]) / shapePerCTA[0];
|
||||
uint32_t repN = (_warpsPerCTA[1] * instrShape[1]) / shapePerCTA[1];
|
||||
|
||||
uint32_t warpsM;
|
||||
if (repM > 1)
|
||||
warpsM = _warpsPerCTA[0] / repM;
|
||||
else
|
||||
warpsM = shape[0] / instrShape[0];
|
||||
|
||||
uint32_t warpsN;
|
||||
if (repN > 1)
|
||||
warpsN = _warpsPerCTA[1] / repN;
|
||||
else
|
||||
warpsN = shape[1] / instrShape[1];
|
||||
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, _warpsPerCTA, order);
|
||||
Value warpId0 = urem(multiDimWarpId[0], i32_val(warpsM));
|
||||
Value warpId1 = urem(multiDimWarpId[1], i32_val(warpsN));
|
||||
|
||||
Value offWarp0 = mul(warpId0, i32_val(instrShape[0]));
|
||||
Value offWarp1 = mul(warpId1, i32_val(instrShape[1]));
|
||||
|
||||
SmallVector<Value> multiDimBase(2);
|
||||
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
|
||||
multiDimBase[1] = add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarp1);
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMmaLayoutV3(const MmaEncodingAttr &mmaLayout,
|
||||
RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
ArrayRef<unsigned int> instrShape = mmaLayout.getInstrShape();
|
||||
|
||||
for (unsigned i = 0; i < shapePerCTA[0];
|
||||
i += getShapePerCTATile(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shapePerCTA[1];
|
||||
j += getShapePerCTATile(mmaLayout)[1]) {
|
||||
for (unsigned k = 0; k < instrShape[1]; k += 8) {
|
||||
ret.push_back({i, j + k});
|
||||
ret.push_back({i, j + k + 1});
|
||||
ret.push_back({i + 8, j + k});
|
||||
ret.push_back({i + 8, j + k + 1});
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Emit indices calculation within each ConversionPattern, and returns a
|
||||
// [elemsPerThread X rank] index matrix.
|
||||
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
|
||||
Location loc, ConversionPatternRewriter &rewriter, Attribute layout,
|
||||
RankedTensorType type) const {
|
||||
RankedTensorType type, bool withCTAOffset) const {
|
||||
// step 1, delinearize threadId to get the base index
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
|
||||
auto multiDimBase =
|
||||
emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset);
|
||||
// step 2, get offset of each element
|
||||
auto offset = emitOffsetForLayout(layout, type);
|
||||
// step 3, add offset to base, and reorder the sequence of indices to
|
||||
// guarantee that elems in the same sizePerThread are adjacent in order
|
||||
// step 3, add offset to base, and reorder the sequence
|
||||
// of indices to guarantee that elems in the same
|
||||
// sizePerThread are adjacent in order
|
||||
auto shape = type.getShape();
|
||||
unsigned rank = shape.size();
|
||||
unsigned elemsPerThread = offset.size();
|
||||
@@ -961,6 +1152,7 @@ protected:
|
||||
TritonGPUToLLVMTypeConverter *converter;
|
||||
ModuleAllocation *allocation;
|
||||
IndexCacheInfo indexCacheInfo;
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata;
|
||||
};
|
||||
|
||||
template <typename SourceOp>
|
||||
@@ -975,18 +1167,18 @@ public:
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, indexCacheInfo) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
|
||||
@@ -994,6 +1186,13 @@ public:
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
|
||||
indexCacheInfo) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata, PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
|
||||
tmaMetadata) {}
|
||||
|
||||
protected:
|
||||
TritonGPUToLLVMTypeConverter *getTypeConverter() const {
|
||||
LLVMTypeConverter *ret =
|
||||
|
||||
@@ -14,20 +14,27 @@
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Membar.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Tools/Sys/GetPlatform.hpp"
|
||||
|
||||
#include "BarrierOpToLLVM.h"
|
||||
#include "ClusterOpsToLLVM.h"
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "ElementwiseOpToLLVM.h"
|
||||
#include "LoadStoreOpToLLVM.h"
|
||||
#include "ReduceOpToLLVM.h"
|
||||
#include "RegReallocOpToLLVM.h"
|
||||
#include "ScanOpToLLVM.h"
|
||||
#include "TensorPtrOpsToLLVM.h"
|
||||
#include "TritonGPUToLLVM.h"
|
||||
#include "TritonGPUToLLVMBase.h"
|
||||
#include "TypeConverter.h"
|
||||
#include "ViewOpToLLVM.h"
|
||||
|
||||
@@ -35,6 +42,7 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
namespace ttng = mlir::triton::nvidia_gpu;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
|
||||
@@ -56,6 +64,30 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class FoldSplatMaskInInsertAsync : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
FoldSplatMaskInInsertAsync(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(
|
||||
triton::nvidia_gpu::InsertSliceAsyncV2Op::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto insertOp = cast<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op);
|
||||
if (!insertOp.getMask())
|
||||
return failure();
|
||||
auto splatOp = insertOp.getMask().getDefiningOp<triton::SplatOp>();
|
||||
if (!splatOp)
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(insertOp, [&]() {
|
||||
insertOp.getMaskMutable().assign(splatOp->getOperand(0));
|
||||
});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
@@ -146,6 +178,18 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
||||
if (!allocation.isRoot(funcOp))
|
||||
amendedFuncOp = amendFuncOp(funcOp, rewriter);
|
||||
|
||||
// Collect TMA informations.
|
||||
unsigned numTMALoad = 0;
|
||||
funcOp.walk(
|
||||
[&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
|
||||
numTMALoad++;
|
||||
});
|
||||
unsigned numTMAStore = 0;
|
||||
funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
|
||||
numTMAStore++;
|
||||
});
|
||||
unsigned numTMA = numTMALoad + numTMAStore;
|
||||
|
||||
auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
|
||||
if (!newFuncOp) {
|
||||
return failure();
|
||||
@@ -171,6 +215,30 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
||||
// The call graph is updated by mapping the old function to the new one.
|
||||
allocation.mapFuncOp(funcOp, newFuncOp);
|
||||
|
||||
// Append arguments to receive TMADesc in global memory in the runtime
|
||||
auto i8PtrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 1);
|
||||
auto numArgs = newFuncOp.getBody().front().getNumArguments();
|
||||
auto funcTy = newFuncOp.getFunctionType().cast<LLVM::LLVMFunctionType>();
|
||||
SmallVector<Type> newInputsTy(funcTy.getParams().begin(),
|
||||
funcTy.getParams().end());
|
||||
for (unsigned i = 0; i < numTMA; ++i) {
|
||||
newFuncOp.getBody().front().addArgument(i8PtrTy, funcOp.getLoc());
|
||||
newInputsTy.push_back(i8PtrTy);
|
||||
}
|
||||
newFuncOp.setType(
|
||||
LLVM::LLVMFunctionType::get(funcTy.getReturnType(), newInputsTy));
|
||||
// required by AxisInfoAnalysis
|
||||
for (unsigned i = 0; i < numTMA; ++i) {
|
||||
newFuncOp.setArgAttr(numArgs + i, "tt.divisibility",
|
||||
rewriter.getIntegerAttr(i32_ty, 1));
|
||||
}
|
||||
|
||||
newFuncOp->setAttr(kAttrNumTMALoadDescsName,
|
||||
rewriter.getIntegerAttr(i32_ty, numTMALoad));
|
||||
newFuncOp->setAttr(kAttrNumTMAStoreDescsName,
|
||||
rewriter.getIntegerAttr(i32_ty, numTMAStore));
|
||||
|
||||
rewriter.eraseOp(funcOp);
|
||||
return success();
|
||||
}
|
||||
@@ -247,7 +315,6 @@ private:
|
||||
this->getTypeConverter()->packFunctionResults(resultTypes)))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto newCallOp = rewriter.create<LLVM::CallOp>(
|
||||
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
|
||||
promotedOperands, callOp->getAttrs());
|
||||
@@ -288,8 +355,10 @@ public:
|
||||
} else {
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
}
|
||||
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
|
||||
addIllegalDialect<triton::TritonDialect>();
|
||||
addIllegalDialect<triton::gpu::TritonGPUDialect>();
|
||||
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
|
||||
addIllegalDialect<mlir::gpu::GPUDialect>();
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
@@ -299,8 +368,11 @@ class ConvertTritonGPUToLLVM
|
||||
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||
|
||||
public:
|
||||
explicit ConvertTritonGPUToLLVM(int computeCapability, bool isROCM)
|
||||
: computeCapability(computeCapability), isROCM(isROCM) {}
|
||||
explicit ConvertTritonGPUToLLVM(int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
bool isROCM)
|
||||
: computeCapability(computeCapability), tmaMetadata(tmaMetadata),
|
||||
isROCM(isROCM) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
@@ -310,19 +382,54 @@ public:
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMConversionTarget target(*context, isROCM);
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
|
||||
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
|
||||
// Preprocess
|
||||
decomposeFp8e4b15Convert(mod);
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
decomposeMixedModeDotOp(mod);
|
||||
|
||||
// Allocate shared memory and set barrier
|
||||
ModuleAllocation allocation(mod);
|
||||
ModuleMembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
/* Get tensorPtrMap before conversion */
|
||||
TensorPtrMapT tensorPtrMap;
|
||||
mod.walk([&tensorPtrMap](
|
||||
mlir::triton::nvidia_gpu::InsertSliceAsyncV2Op insertOp) {
|
||||
auto src = insertOp.getSrc();
|
||||
auto ptrTy = src.getType().dyn_cast<triton::PointerType>();
|
||||
if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(insertOp.getSrc());
|
||||
tensorPtrMap[insertOp.getOperation()] = makeTensorPtrOp;
|
||||
}
|
||||
});
|
||||
|
||||
mod.walk([&tensorPtrMap](mlir::triton::nvidia_gpu::StoreAsyncOp storeOp) {
|
||||
auto dst = storeOp.getDst();
|
||||
auto ptrTy = dst.getType().dyn_cast<triton::PointerType>();
|
||||
if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(storeOp.getDst());
|
||||
tensorPtrMap[storeOp.getOperation()] = makeTensorPtrOp;
|
||||
}
|
||||
});
|
||||
|
||||
// Hack: cleanup
|
||||
{
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<FoldSplatMaskInInsertAsync>(context);
|
||||
SmallVector<Operation *> insertSlices;
|
||||
mod.walk([&insertSlices](triton::nvidia_gpu::InsertSliceAsyncV2Op op) {
|
||||
insertSlices.push_back(op);
|
||||
});
|
||||
if (applyOpPatternsAndFold(insertSlices, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
// Lower functions
|
||||
{
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
@@ -358,9 +465,14 @@ public:
|
||||
}
|
||||
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
|
||||
// Rewrite ops
|
||||
RewritePatternSet patterns(context);
|
||||
// TritonGPU lowering patterns
|
||||
|
||||
// Emit logics to get threadId/blockIds/linearized clusterCTAId etc. and
|
||||
// cache the values. The reason to do it here is that cluster_ctaid is
|
||||
// currently implemented via inline asm, and thus cannot be CSEed.
|
||||
// clusterCTAId will be emitted only when numCTAs is larger than 1, and
|
||||
// other values will be DCEed if not used hereafter.
|
||||
bool isWarpSpecialization =
|
||||
ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod);
|
||||
OpBuilder::InsertPoint indexInsertPoint;
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
|
||||
&baseIndexCache, &indexCache, &indexInsertPoint};
|
||||
@@ -368,21 +480,56 @@ public:
|
||||
if (axisInfoAnalysis.getNumFunctions() > 1) {
|
||||
indexCacheInfo = {nullptr, nullptr, nullptr};
|
||||
}
|
||||
populateTritonGPUToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
indexCacheInfo, /*benefit=*/1);
|
||||
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
indexCacheInfo, /*benefit=*/1);
|
||||
populateDotOpToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
/*benefit=*/1);
|
||||
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
|
||||
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo,
|
||||
/*benefit=*/1);
|
||||
populateReduceOpToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
indexCacheInfo, /*benefit=*/1);
|
||||
populateScanOpToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
indexCacheInfo, /*benefit=*/1);
|
||||
populateViewOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
|
||||
|
||||
// tmaMetadata is absent in a triton-opt unit test, in this case, create a
|
||||
// local one and dump it after this pass is done.
|
||||
mlir::triton::gpu::TMAMetadataTy tmaMetaDataDebug;
|
||||
if (tmaMetadata == nullptr)
|
||||
tmaMetadata = &tmaMetaDataDebug;
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
auto populatePatterns1 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo,
|
||||
/*benefit*/ 10);
|
||||
};
|
||||
|
||||
auto populatePatterns2 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, /*benefit*/ 10);
|
||||
};
|
||||
|
||||
auto populatePatterns3 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo, tmaMetadata, &tensorPtrMap,
|
||||
/*benefit*/ 10);
|
||||
};
|
||||
|
||||
auto populatePatterns4 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo, computeCapability,
|
||||
/*benefit*/ 10);
|
||||
};
|
||||
|
||||
populatePatterns1(populateTritonGPUToLLVMPatterns);
|
||||
populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
|
||||
populatePatterns2(populateDotOpToLLVMPatterns);
|
||||
populatePatterns2(populateElementwiseOpToLLVMPatterns);
|
||||
populatePatterns3(populateLoadStoreOpToLLVMPatterns);
|
||||
populatePatterns4(populateReduceOpToLLVMPatterns);
|
||||
populatePatterns1(populateScanOpToLLVMPatterns);
|
||||
populatePatterns2(populateViewOpToLLVMPatterns);
|
||||
populatePatterns2(populateBarrierOpToLLVMPatterns);
|
||||
populatePatterns2(populateTensorPtrOpsToLLVMPatterns);
|
||||
populatePatterns2(populateClusterOpsToLLVMPatterns);
|
||||
populatePatterns2(populateRegReallocOpToLLVMPatterns);
|
||||
|
||||
// TODO(thomas): this should probably be done in a separate step to not
|
||||
// interfere with our own lowering of arith ops. Add arith/math's patterns
|
||||
// to help convert scalar expression to LLVM.
|
||||
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// Native lowering patterns
|
||||
if (isROCM) {
|
||||
@@ -396,10 +543,18 @@ public:
|
||||
patterns);
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Fold CTAId when there is only 1 CTA.
|
||||
if (numCTAs == 1) {
|
||||
mod.walk([](triton::nvgpu::ClusterCTAIdOp id) {
|
||||
OpBuilder b(id);
|
||||
Value zero = LLVM::createConstantI32(id->getLoc(), b, 0);
|
||||
id.replaceAllUsesWith(zero);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
baseIndexCache;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||
@@ -408,6 +563,7 @@ private:
|
||||
|
||||
int computeCapability{};
|
||||
bool isROCM{};
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata;
|
||||
|
||||
void initSharedMemory(ModuleAllocation &allocation,
|
||||
TritonGPUToLLVMTypeConverter &typeConverter) {
|
||||
@@ -470,8 +626,8 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps,
|
||||
int threadsPerWarp) const {
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
|
||||
int numCTAs) const {
|
||||
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
@@ -487,7 +643,7 @@ private:
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
|
||||
getOrder(srcMma), numWarps, threadsPerWarp));
|
||||
getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -514,7 +670,8 @@ private:
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
mod.getContext(), dstDotOp, srcType.getShape(),
|
||||
getOrder(srcBlocked), srcType.getElementType()));
|
||||
srcBlocked.getOrder(), srcBlocked.getCTALayout(),
|
||||
srcType.getElementType()));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -632,6 +789,52 @@ private:
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
|
||||
Type promotedType) {
|
||||
Type tensorPromotedType =
|
||||
operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
|
||||
promotedType);
|
||||
return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
|
||||
}
|
||||
|
||||
// promote operands of dot op if the existing combination is not natively
|
||||
// supported.
|
||||
void decomposeMixedModeDotOp(ModuleOp mod) const {
|
||||
mod.walk([](triton::DotOp dotOp) -> void {
|
||||
Value D = dotOp.getResult();
|
||||
OpBuilder builder(dotOp);
|
||||
Type AElType =
|
||||
dotOp.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
Type promoteType;
|
||||
MmaEncodingAttr mmaLayout = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MmaEncodingAttr>();
|
||||
if (mmaLayout) {
|
||||
bool isNativeHopperFP8 =
|
||||
AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
|
||||
bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() ||
|
||||
AElType.isFloat8E4M3FNUZ();
|
||||
if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper()))
|
||||
return;
|
||||
promoteType = builder.getF16Type();
|
||||
} else {
|
||||
// FMA case.
|
||||
Type AElType =
|
||||
dotOp.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
Type DElType = D.getType().cast<RankedTensorType>().getElementType();
|
||||
if (AElType == DElType)
|
||||
return;
|
||||
promoteType = DElType;
|
||||
}
|
||||
Location loc = dotOp.getLoc();
|
||||
Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType);
|
||||
Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType);
|
||||
dotOp.setOperand(0, promotedA);
|
||||
dotOp.setOperand(1, promotedB);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
@@ -640,8 +843,11 @@ namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability, bool isROCM) {
|
||||
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability, isROCM);
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
bool isROCM) {
|
||||
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability,
|
||||
tmaMetadata, isROCM);
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
|
||||
@@ -41,7 +41,27 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
|
||||
|
||||
Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
|
||||
triton::PointerType type) {
|
||||
// Recursively translate pointee type
|
||||
auto ctx = type.getContext();
|
||||
auto pointeeType = type.getPointeeType();
|
||||
if (pointeeType.isa<RankedTensorType>()) {
|
||||
auto rankedTensorType = pointeeType.cast<RankedTensorType>();
|
||||
// struct { offset0, offset1, shape0, shape1, stride0,
|
||||
// stride1, base_ptr};
|
||||
auto eleType = rankedTensorType.getElementType();
|
||||
auto shape = rankedTensorType.getShape();
|
||||
SmallVector<Type, 4> types;
|
||||
// offsets
|
||||
for (size_t i = 0; i < shape.size(); ++i)
|
||||
types.push_back(IntegerType::get(ctx, 32));
|
||||
// shapes, strides
|
||||
for (size_t i = 0; i < 2 * shape.size(); ++i)
|
||||
types.push_back(IntegerType::get(ctx, 64));
|
||||
|
||||
types.push_back(
|
||||
LLVM::LLVMPointerType::get(eleType, type.getAddressSpace()));
|
||||
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
}
|
||||
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
|
||||
type.getAddressSpace());
|
||||
}
|
||||
|
||||
@@ -6,19 +6,19 @@ namespace mlir {
|
||||
namespace LLVM {
|
||||
using namespace mlir::triton;
|
||||
|
||||
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
|
||||
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) {
|
||||
auto i32ty = rewriter.getIntegerType(32);
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
||||
IntegerAttr::get(i32ty, v));
|
||||
}
|
||||
|
||||
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
|
||||
Value createConstantF32(Location loc, OpBuilder &rewriter, float v) {
|
||||
auto type = type::f32Ty(rewriter.getContext());
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, type,
|
||||
rewriter.getF32FloatAttr(v));
|
||||
}
|
||||
|
||||
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
|
||||
Value createConstantF64(Location loc, OpBuilder &rewriter, float v) {
|
||||
auto type = type::f64Ty(rewriter.getContext());
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, type,
|
||||
rewriter.getF64FloatAttr(v));
|
||||
@@ -40,6 +40,96 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
}
|
||||
|
||||
// A wrapper of LoadDSmemOp when vec = 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Create LoadDSmemOp
|
||||
// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy
|
||||
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
Value ret =
|
||||
rewriter.create<triton::nvgpu::LoadDSmemOp>(loc, addr, ctaId, bitwidth);
|
||||
return bitcast(ret, elemTy);
|
||||
}
|
||||
|
||||
// A wrapper of LoadDSmemOp when vec > 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Create LoadDSmemOp and extract results from retStruct
|
||||
// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy
|
||||
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
|
||||
Value addr, Value ctaId, unsigned vec) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
Value retStruct = rewriter.create<triton::nvgpu::LoadDSmemOp>(
|
||||
loc, addr, ctaId, bitwidth, vec);
|
||||
SmallVector<Value> retVals;
|
||||
for (unsigned i = 0; i < vec; ++i) {
|
||||
auto dataTy = rewriter.getIntegerType(bitwidth);
|
||||
Value data = extract_val(dataTy, retStruct, i);
|
||||
retVals.push_back(bitcast(data, elemTy));
|
||||
}
|
||||
return retVals;
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec = 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Bitcast value from elemTy to dataTy (u16/u32/u64)
|
||||
// (3) Create StoreDSmemOp
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value, Value pred) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
auto dataTy = rewriter.getIntegerType(bitwidth);
|
||||
Value data = bitcast(value, dataTy);
|
||||
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec = 1 and pred = 1
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value) {
|
||||
Value pred = int_val(/*width=*/1, 1);
|
||||
createStoreDSmem(loc, rewriter, addr, ctaId, value, pred);
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec > 1
|
||||
// (1) Get bitwidth from elemTy
|
||||
// (2) Bitcast values from elemTy to dataTy (u16/u32/u64)
|
||||
// (3) Create StoreDSmemOp
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values, Value pred) {
|
||||
assert(addr.getType().isa<LLVMPointerType>() &&
|
||||
"addr must be a pointer type");
|
||||
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
||||
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
||||
auto elemTy = ptrTy.getElementType();
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
auto dataTy = rewriter.getIntegerType(bitwidth);
|
||||
SmallVector<Value> data;
|
||||
for (unsigned i = 0; i < values.size(); ++i)
|
||||
data.push_back(bitcast(values[i], dataTy));
|
||||
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
|
||||
}
|
||||
|
||||
// A wrapper of StoreDSmemOp when vec > 1 and pred = 1
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values) {
|
||||
Value pred = int_val(/*width=*/1, 1);
|
||||
createStoreDSmem(loc, rewriter, addr, ctaId, values, pred);
|
||||
}
|
||||
|
||||
SharedMemoryObject
|
||||
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
||||
#define sext(...) rewriter.create<LLVM::SExtOp>(loc, __VA_ARGS__)
|
||||
#define fpext(...) rewriter.create<LLVM::FPExtOp>(loc, __VA_ARGS__)
|
||||
#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
|
||||
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
|
||||
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
|
||||
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
|
||||
@@ -43,6 +44,8 @@
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
|
||||
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
|
||||
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
|
||||
#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__)
|
||||
#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__)
|
||||
#define fcmp_ogt(lhs, rhs) \
|
||||
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
|
||||
LLVM::FCmpPredicate::ogt, lhs, rhs)
|
||||
@@ -75,6 +78,15 @@
|
||||
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
|
||||
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
|
||||
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
|
||||
#define barSync(rewriter, op, bar, numThreads) \
|
||||
do { \
|
||||
::mlir::triton::PTXBuilder ptxBuilder; \
|
||||
auto &barSyncOp = *ptxBuilder.create<>("bar.sync"); \
|
||||
barSyncOp(ptxBuilder.newConstantOperand(bar), \
|
||||
ptxBuilder.newConstantOperand(numThreads)); \
|
||||
auto voidTy = void_ty(op->getContext()); \
|
||||
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
|
||||
} while (0)
|
||||
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
|
||||
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
|
||||
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
|
||||
@@ -84,6 +96,8 @@
|
||||
#define i64_ty rewriter.getIntegerType(64)
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define i16_ty rewriter.getIntegerType(16)
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define i64_ty rewriter.getIntegerType(64)
|
||||
#define ui32_ty rewriter.getIntegerType(32, false)
|
||||
#define f16_ty rewriter.getF16Type()
|
||||
#define bf16_ty rewriter.getBF16Type()
|
||||
@@ -174,13 +188,13 @@ T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
|
||||
namespace LLVM {
|
||||
using namespace mlir::triton;
|
||||
|
||||
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v);
|
||||
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);
|
||||
|
||||
/// Create a 32-bit float constant.
|
||||
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v);
|
||||
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);
|
||||
|
||||
/// Create a 64-bit float constant.
|
||||
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v);
|
||||
Value createConstantF64(Location loc, OpBuilder &rewriter, float v);
|
||||
|
||||
/// Create an index type constant.
|
||||
Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
@@ -190,6 +204,28 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
int64_t value);
|
||||
|
||||
/// Usage of macro load_dsmem
|
||||
/// (1) load_dsmem(addr, ctaId)
|
||||
/// (2) load_dsmem(addr, ctaId, vec)
|
||||
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId);
|
||||
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
|
||||
Value addr, Value ctaId, unsigned vec);
|
||||
|
||||
/// Usage of macro store_dsmem
|
||||
/// (1) store_dsmem(addr, ctaId, value, pred)
|
||||
/// (2) store_dsmem(addr, ctaId, value)
|
||||
/// (3) store_dsmem(addr, ctaId, values, pred)
|
||||
/// (4) store_dsmem(addr, ctaId, values)
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value, Value pred);
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, Value value);
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values, Value pred);
|
||||
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
||||
Value ctaId, ArrayRef<Value> values);
|
||||
|
||||
/// Helper function to get strides from a given shape and its order
|
||||
SmallVector<Value>
|
||||
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
|
||||
|
||||
@@ -104,6 +104,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
using OpAdaptor = typename CatOp::Adaptor;
|
||||
|
||||
explicit CatOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
|
||||
|
||||
@@ -138,6 +139,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
|
||||
using OpAdaptor = typename ViewOp::Adaptor;
|
||||
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}
|
||||
|
||||
@@ -159,6 +161,7 @@ struct ExpandDimsOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp> {
|
||||
using OpAdaptor = typename ExpandDimsOp::Adaptor;
|
||||
explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp>(typeConverter, benefit) {}
|
||||
|
||||
@@ -221,7 +224,9 @@ struct TransOpConversion
|
||||
};
|
||||
|
||||
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<ViewOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
|
||||
|
||||
@@ -7,7 +7,9 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
#include <numeric>
|
||||
|
||||
@@ -240,10 +241,19 @@ struct TritonExpandDimsPattern
|
||||
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
|
||||
SmallVector<unsigned, 4> retOrder(retShape.size());
|
||||
std::iota(retOrder.begin(), retOrder.end(), 0);
|
||||
|
||||
auto argCTALayout = argEncoding.getCTALayout();
|
||||
auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis());
|
||||
auto retCTASplitNum =
|
||||
insertOne(argCTALayout.getCTASplitNum(), op.getAxis());
|
||||
auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis());
|
||||
auto retCTALayout = triton::gpu::CTALayoutAttr::get(
|
||||
getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder);
|
||||
|
||||
triton::gpu::BlockedEncodingAttr retEncoding =
|
||||
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
retOrder, retCTALayout);
|
||||
// convert operand to slice of return type
|
||||
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
|
||||
getContext(), op.getAxis(), retEncoding);
|
||||
@@ -257,6 +267,26 @@ struct TritonExpandDimsPattern
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
SmallVector<T> insertOne(ArrayRef<T> vec, unsigned axis) const {
|
||||
SmallVector<T> res(vec.begin(), vec.end());
|
||||
res.insert(res.begin() + axis, 1);
|
||||
return res;
|
||||
}
|
||||
|
||||
// Example: order = [ 0, 2, 1, 3], dim = 2
|
||||
// resOrder = [2, 0, 3, 1, 4]
|
||||
SmallVector<unsigned> insertOrder(ArrayRef<unsigned> order,
|
||||
unsigned axis) const {
|
||||
SmallVector<unsigned> resOrder(order.begin(), order.end());
|
||||
for (unsigned i = 0; i < resOrder.size(); ++i)
|
||||
if (resOrder[i] >= axis)
|
||||
++resOrder[i];
|
||||
resOrder.insert(resOrder.begin(), axis);
|
||||
return resOrder;
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
@@ -270,6 +300,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
|
||||
int numWarps = typeConverter->getNumWarps();
|
||||
int threadsPerWarp = typeConverter->getThreadsPerWarp();
|
||||
int numCTAs = typeConverter->getNumCTAs();
|
||||
|
||||
SmallVector<unsigned> retSizePerThread = {1, 1};
|
||||
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 4)
|
||||
@@ -279,7 +310,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
SmallVector<unsigned> retOrder = {1, 0};
|
||||
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
getContext(), origShape, retSizePerThread, retOrder, numWarps,
|
||||
threadsPerWarp);
|
||||
threadsPerWarp, numCTAs);
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
|
||||
// a & b must be of smem layout
|
||||
@@ -354,9 +385,9 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
|
||||
newRetSizePerThread[retOrder[0]] *=
|
||||
newRetTotalElemsPerThread / retTotalElemsPerThread;
|
||||
triton::gpu::BlockedEncodingAttr newRetEncoding =
|
||||
triton::gpu::BlockedEncodingAttr::get(getContext(), newRetSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
triton::gpu::BlockedEncodingAttr::get(
|
||||
getContext(), newRetSizePerThread, retThreadsPerWarp,
|
||||
retWarpsPerCTA, retOrder, retEncoding.getCTALayout());
|
||||
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
|
||||
newRetEncoding);
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
|
||||
@@ -386,8 +417,12 @@ struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
|
||||
if (auto srcBlockedEncoding =
|
||||
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
|
||||
srcEncoding =
|
||||
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
// TODO(Qingyi): need to check whether the CTALayout of srcEncoding should
|
||||
// be used here. For tests where numCTAs = 1, this is not a problem since
|
||||
// all CTALayouts are the same.
|
||||
auto CTALayout = triton::gpu::getCTALayout(srcEncoding);
|
||||
srcEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1,
|
||||
order, CTALayout);
|
||||
srcType = RankedTensorType::get(srcType.getShape(),
|
||||
srcType.getElementType(), srcEncoding);
|
||||
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
|
||||
@@ -658,10 +693,12 @@ public:
|
||||
};
|
||||
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
RewritePatternSet &patterns, unsigned numCTAs) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns
|
||||
.insert< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::AdvanceOp>,
|
||||
TritonGenericPattern<triton::MakeTensorPtrOp>,
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::BitcastOp>,
|
||||
TritonGenericPattern<triton::FpToFpOp>,
|
||||
@@ -889,16 +926,20 @@ class ConvertTritonToTritonGPU
|
||||
public:
|
||||
ConvertTritonToTritonGPU() = default;
|
||||
// constructor with some parameters set explicitly.
|
||||
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp) {
|
||||
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp, int numCTAs,
|
||||
int computeCapability) {
|
||||
this->numWarps = numWarps;
|
||||
this->threadsPerWarp = threadsPerWarp;
|
||||
this->numCTAs = numCTAs;
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp);
|
||||
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp,
|
||||
numCTAs);
|
||||
TritonGPUConversionTarget target(*context, typeConverter);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
@@ -906,7 +947,7 @@ public:
|
||||
populateStdPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateArithPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateMathPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
populateTritonPatterns(typeConverter, patterns, numCTAs);
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
@@ -925,6 +966,13 @@ public:
|
||||
AttrNumThreadsPerWarp,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue())));
|
||||
|
||||
mod->setAttr(AttrNumCTAsName,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue())));
|
||||
|
||||
mod->setAttr(AttrComputeCapabilityName,
|
||||
IntegerAttr::get(
|
||||
i32_ty, llvm::APInt(32, computeCapability.getValue())));
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
// if (failed(target.refineLayouts(mod, numWarps)))
|
||||
@@ -936,8 +984,11 @@ public:
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps,
|
||||
int threadsPerWarp) {
|
||||
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps, threadsPerWarp);
|
||||
int threadsPerWarp,
|
||||
int numCTAs,
|
||||
int computeCapability) {
|
||||
return std::make_unique<::ConvertTritonToTritonGPU>(
|
||||
numWarps, threadsPerWarp, numCTAs, computeCapability);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_subdirectory(Triton)
|
||||
add_subdirectory(TritonGPU)
|
||||
add_subdirectory(TritonNvidiaGPU)
|
||||
add_subdirectory(NVGPU)
|
||||
|
||||
2
lib/Dialect/NVGPU/CMakeLists.txt
Normal file
2
lib/Dialect/NVGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(ToLLVMIR)
|
||||
9
lib/Dialect/NVGPU/IR/CMakeLists.txt
Normal file
9
lib/Dialect/NVGPU/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
add_mlir_dialect_library(NVGPUIR
|
||||
Dialect.cpp
|
||||
|
||||
DEPENDS
|
||||
NVGPUTableGen
|
||||
NVGPUAttrDefsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user