mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '36fc54b6f28168d3644808bfe299f1ba06a36272' into ifu230908-2
Conflicts: .gitignore bin/triton-translate.cpp include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/triton_to_tritongpu.mlir test/Conversion/tritongpu_to_llvm.mlir test/TritonGPU/coalesce.mlir unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
|
||||
@@ -158,17 +159,17 @@ private:
|
||||
BufferKind kind;
|
||||
BufferId id;
|
||||
size_t size;
|
||||
size_t alignment;
|
||||
size_t offset;
|
||||
|
||||
bool operator==(const BufferT &other) const { return id == other.id; }
|
||||
bool operator<(const BufferT &other) const { return id < other.id; }
|
||||
|
||||
BufferT() : BufferT(BufferKind::Explicit) {}
|
||||
BufferT(BufferKind kind)
|
||||
: kind(kind), id(InvalidBufferId), size(0), offset(0) {}
|
||||
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t offset)
|
||||
: kind(kind), id(nextId++), size(size), offset(offset) {}
|
||||
BufferT() : BufferT(BufferKind::Explicit, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
|
||||
size_t offset = 0)
|
||||
: kind(kind), id(nextId++), size(size), alignment(alignment),
|
||||
offset(offset) {}
|
||||
};
|
||||
|
||||
/// Op -> Scratch Buffer
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace mlir {
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef SmallVector<int64_t, 4> DimVectorT;
|
||||
typedef SmallVector<int64_t> DimVectorT;
|
||||
|
||||
public:
|
||||
/// Default constructor
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
@@ -125,7 +126,11 @@ bool isSingleValue(Value value);
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
Type getElementType(Value value);
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
// TODO: Move utility functions that belong to ConvertLayoutOp to class
|
||||
// ConvertLayoutOpHelper in the future
|
||||
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
|
||||
@@ -332,6 +337,10 @@ protected:
|
||||
FuncDataMapT funcMap;
|
||||
SmallVector<FunctionOpInterface> roots;
|
||||
};
|
||||
// Create a basic DataFlowSolver with constant and dead code analysis included.
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
|
||||
|
||||
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(TritonToTritonGPU)
|
||||
add_subdirectory(TritonGPUToLLVM)
|
||||
add_subdirectory(NVGPUToLLVM)
|
||||
|
||||
3
include/triton/Conversion/NVGPUToLLVM/CMakeLists.txt
Normal file
3
include/triton/Conversion/NVGPUToLLVM/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls --name NVGPUToLLVM)
|
||||
add_public_tablegen_target(NVGPUConversionPassIncGen)
|
||||
19
include/triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h
Normal file
19
include/triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H
|
||||
#define TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertNVGPUToLLVMPass();
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
16
include/triton/Conversion/NVGPUToLLVM/Passes.h
Normal file
16
include/triton/Conversion/NVGPUToLLVM/Passes.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef NVGPU_CONVERSION_PASSES_H
|
||||
#define NVGPU_CONVERSION_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Conversion/NVGPUToLLVM/Passes.h.inc"
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
20
include/triton/Conversion/NVGPUToLLVM/Passes.td
Normal file
20
include/triton/Conversion/NVGPUToLLVM/Passes.td
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef NVGPU_CONVERSION_PASSES
|
||||
#define NVGPU_CONVERSION_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
|
||||
def ConvertNVGPUToLLVM : Pass<"convert-nv-gpu-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Convert NVGPU to LLVM";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertNVGPUToLLVMPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithDialect",
|
||||
"mlir::LLVM::LLVMDialect",
|
||||
"mlir::NVVM::NVVMDialect",
|
||||
"mlir::triton::nvgpu::NVGPUDialect"];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -151,6 +151,12 @@ struct PTXBuilder {
|
||||
// aggressive optimizations that may lead to incorrect results.
|
||||
Operand *newOperand(StringRef constraint, bool init = false);
|
||||
|
||||
// Create a new operand that is tied to a previous operand. In this case the
|
||||
// asm would be permitted to write to an input register. Instead of providing
|
||||
// constraint code for this operand, the constraint code of the tied operand
|
||||
// is used.
|
||||
Operand *newOperand(unsigned operandIndex);
|
||||
|
||||
// Create a constant integer operand.
|
||||
Operand *newConstantOperand(int64_t v);
|
||||
// Create a constant operand with explicit code specified.
|
||||
|
||||
@@ -19,6 +19,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
"mlir::tensor::TensorDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::ROCDL::ROCDLDialect",
|
||||
"mlir::NVVM::NVVMDialect"];
|
||||
|
||||
@@ -26,9 +27,16 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">,
|
||||
Option<"isROCM", "is-rocm",
|
||||
"bool", /*default*/"false",
|
||||
"compile for ROCM-compatible LLVM">,
|
||||
Option<"tmaMetadata", "tma-metadata",
|
||||
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
|
||||
"tma metadata to the runtime">,
|
||||
Option<"target", "target", "enum Target", "mlir::triton::Target::Default",
|
||||
"compile for target compatible LLVM",
|
||||
"llvm::cl::values("
|
||||
"clEnumValN(mlir::triton::Target::NVVM, \"nvvm\", \"compile for "
|
||||
"NVVM-compatible LLVM\"), "
|
||||
"clEnumValN(mlir::triton::Target::ROCDL, \"rocdl\", \"compile for "
|
||||
"ROCDL-compatible LLVM\"))">,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
@@ -12,7 +14,14 @@ template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
enum Target { NVVM, ROCDL, Default = NVVM };
|
||||
|
||||
#define GEN_PASS_DECL
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
|
||||
bool isROCM = true);
|
||||
@@ -20,6 +29,10 @@ createConvertTritonGPUToLLVMPass(int computeCapability = 80,
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
|
||||
bool isROCM = false);
|
||||
#endif
|
||||
=======
|
||||
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options);
|
||||
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
@@ -25,6 +25,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
||||
Option<"threadsPerWarp", "threads-per-warp",
|
||||
"int32_t", /*default*/"TRITONGPU_DEFAULT_WARPSIZE",
|
||||
"number of threads per warp">,
|
||||
Option<"numCTAs", "num-ctas",
|
||||
"int32_t", /*default*/"1",
|
||||
"number of ctas in a cga">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,9 @@ template <typename T> class OperationPass;
|
||||
namespace triton {
|
||||
|
||||
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
|
||||
constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
|
||||
constexpr static char AttrComputeCapabilityName[] =
|
||||
"triton_gpu.compute-capability";
|
||||
|
||||
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
|
||||
|
||||
@@ -19,7 +22,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
|
||||
|
||||
// Create the pass with numWarps set explicitly.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32);
|
||||
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32,
|
||||
int numCTAs = 1, int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_subdirectory(Triton)
|
||||
add_subdirectory(TritonGPU)
|
||||
add_subdirectory(TritonNvidiaGPU)
|
||||
add_subdirectory(NVGPU)
|
||||
|
||||
2
include/triton/Dialect/NVGPU/CMakeLists.txt
Normal file
2
include/triton/Dialect/NVGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
#add_subdirectory(Transforms)
|
||||
14
include/triton/Dialect/NVGPU/IR/CMakeLists.txt
Normal file
14
include/triton/Dialect/NVGPU/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
set(LLVM_TARGET_DEFINITIONS NVGPUOps.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu)
|
||||
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(NVGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td)
|
||||
mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
add_public_tablegen_target(NVGPUAttrDefsIncGen)
|
||||
47
include/triton/Dialect/NVGPU/IR/Dialect.h
Normal file
47
include/triton/Dialect/NVGPU/IR/Dialect.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_NVGPU_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_NVGPU_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/NVGPU/IR/OpsEnums.h.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/NVGPU/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace nvgpu {} // namespace nvgpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
33
include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td
Normal file
33
include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef NVGPU_ATTRDEFS
|
||||
#define NVGPU_ATTRDEFS
|
||||
|
||||
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
|
||||
class NVGPU_Attr<string name, list<Trait> traits = [],
|
||||
string baseCppClass = "::mlir::Attribute">
|
||||
: AttrDef<NVGPU_Dialect, name, traits, baseCppClass> {
|
||||
}
|
||||
|
||||
#endif
|
||||
40
include/triton/Dialect/NVGPU/IR/NVGPUDialect.td
Normal file
40
include/triton/Dialect/NVGPU/IR/NVGPUDialect.td
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef NVGPU_DIALECT
|
||||
#define NVGPU_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def NVGPU_Dialect : Dialect {
|
||||
let name = "nvgpu";
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
|
||||
let description = [{
|
||||
NVGPU Dialect.
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"mlir::LLVM::LLVMDialect"
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
248
include/triton/Dialect/NVGPU/IR/NVGPUOps.td
Normal file
248
include/triton/Dialect/NVGPU/IR/NVGPUOps.td
Normal file
@@ -0,0 +1,248 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef NVGPU_OPS
|
||||
#define NVGPU_OPS
|
||||
|
||||
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
|
||||
include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
|
||||
def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
|
||||
def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
|
||||
def I64Ptr_shared : LLVM_IntPtrBase<64, 3>;
|
||||
|
||||
class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
|
||||
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
|
||||
def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", []> {
|
||||
let arguments = (ins I32Attr:$pendings);
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count);
|
||||
let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)";
|
||||
}
|
||||
|
||||
def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType",
|
||||
"mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'",
|
||||
[
|
||||
I32EnumAttrCase<"normal", 0>,
|
||||
I32EnumAttrCase<"cp_async", 1>,
|
||||
I32EnumAttrCase<"expect_tx", 2>,
|
||||
I32EnumAttrCase<"remote", 3>,
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> {
|
||||
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional<I32>:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr<I32Attr, "0">:$txCount);
|
||||
let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)";
|
||||
}
|
||||
|
||||
def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> {
|
||||
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase);
|
||||
let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> {
|
||||
let arguments = (ins I32:$bar, I32:$numThreads);
|
||||
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> {
|
||||
let arguments = (ins I32:$bar, I32:$numThreads);
|
||||
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def WGMMADesc_ModeAttr : I32EnumAttr<"WGMMADescMode",
|
||||
"wgmma desc mode, either 'none', 'swizzle128', 'swizzle64', or 'swizzle32'",
|
||||
[
|
||||
I32EnumAttrCase<"none", 0>,
|
||||
I32EnumAttrCase<"swizzle128", 1>,
|
||||
I32EnumAttrCase<"swizzle64", 2>,
|
||||
I32EnumAttrCase<"swizzle32", 3>
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> {
|
||||
let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode);
|
||||
let results = (outs I64:$res);
|
||||
let assemblyFormat = "$buffer `,` $height attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> {
|
||||
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc,
|
||||
I1:$pred, Variadic<I32>:$coords, Optional<I16>:$mcastMask);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> {
|
||||
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic<I32>:$coords, I16Attr:$mcastMask);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout",
|
||||
"wgmma layout, either 'row' or 'col'",
|
||||
[
|
||||
I32EnumAttrCase<"row", 0>,
|
||||
I32EnumAttrCase<"col", 1>
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType",
|
||||
"wgmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'",
|
||||
[
|
||||
I32EnumAttrCase<"s8", 0>,
|
||||
I32EnumAttrCase<"s32", 1>,
|
||||
I32EnumAttrCase<"e4m3", 2>,
|
||||
I32EnumAttrCase<"e5m2", 3>,
|
||||
I32EnumAttrCase<"f16", 4>,
|
||||
I32EnumAttrCase<"bf16", 5>,
|
||||
I32EnumAttrCase<"tf32", 6>,
|
||||
I32EnumAttrCase<"f32", 7>
|
||||
]>{
|
||||
let cppNamespace = "::mlir::triton::nvgpu";
|
||||
}
|
||||
|
||||
def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">;
|
||||
|
||||
def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
|
||||
let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
|
||||
I32Attr:$m, I32Attr:$n, I32Attr:$k,
|
||||
WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
|
||||
WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
|
||||
let results = (outs LLVM_AnyStruct:$res);
|
||||
let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
|
||||
}
|
||||
|
||||
def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_CGABarrierArriveOp : NVGPU_Op<"cga_barrier_arrive", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_CGABarrierWaitOp : NVGPU_Op<"cga_barrier_wait", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> {
|
||||
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Type":$resultTy, "Value":$addr, "Value":$ctaId)>,
|
||||
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth, "unsigned":$vec)>,
|
||||
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth)>
|
||||
];
|
||||
let results = (outs LLVM_LoadableType:$result);
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def NVGPU_StoreDSmemOp : NVGPU_Op<"store_dsmem", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId,
|
||||
Variadic<LLVM_LoadableType>:$values, I1:$pred);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "Value":$value, "Value":$pred)>,
|
||||
];
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
let extraClassDeclaration = [{
|
||||
unsigned getBitwidth();
|
||||
unsigned getVec();
|
||||
}];
|
||||
}
|
||||
|
||||
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
|
||||
let arguments = (ins BoolAttr:$bCluster);
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_FenceMBarrierInitOp : NVGPU_Op<"fence_mbarrier_init", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> {
|
||||
let arguments = (ins I1Attr:$relaxed);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic<I32>:$coords);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I8Ptr_shared:$addr, Variadic<I32>:$datas);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def NVGPU_OffsetOfStmatrixV4Op : NVGPU_Op<"offset_of_stmatrix_v4", []> {
|
||||
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
|
||||
let results = (outs I32:$offset);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
|
||||
}
|
||||
|
||||
def NVGPU_OffsetOfSts64Op : NVGPU_Op<"offset_of_sts64", []> {
|
||||
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
|
||||
let results = (outs I32:$offset);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
|
||||
}
|
||||
|
||||
def NVGPU_Sts64Op : NVGPU_Op<"sts64", [MemoryEffects<[MemWrite]>]> {
|
||||
let arguments = (ins I32:$offset, AnyTypeOf<[F32, I32]>:$d0, AnyTypeOf<[F32, I32]>:$d1);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
|
||||
let results = (outs I32:$result);
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> {
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> {
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
||||
|
||||
@@ -9,6 +9,8 @@ include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
|
||||
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
|
||||
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
|
||||
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
|
||||
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
@@ -135,7 +137,7 @@ def TT_LoadOp : TT_Op<"load",
|
||||
[SameLoadStoreOperandsAndResultShape,
|
||||
SameLoadStoreOperandsAndResultEncoding,
|
||||
AttrSizedOperandSegments,
|
||||
MemoryEffects<[MemRead]>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr", "$_self",
|
||||
"mlir::OpTrait::impl::verifyLoadStorePointerAndValueType">,
|
||||
@@ -461,33 +463,23 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
|
||||
//
|
||||
// External Elementwise op
|
||||
//
|
||||
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
|
||||
TT_Op<mnemonic,
|
||||
traits # [Elementwise,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
|
||||
let description = [{
|
||||
call an external function $symbol implemented in $libpath/$libname with $args
|
||||
return $libpath/$libname:$symbol($args...)
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
|
||||
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
|
||||
}
|
||||
|
||||
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
|
||||
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
|
||||
}
|
||||
|
||||
def TT_ImpureExternElementwiseOp : TT_ExternElementwiseOpBase<"impure_extern_elementwise", [MemoryEffects<[MemRead]>,
|
||||
MemoryEffects<[MemWrite]>]> {
|
||||
let summary = "FFI for impure element-wise extern LLVM bitcode functions";
|
||||
}
|
||||
|
||||
//
|
||||
// Make Range Op
|
||||
//
|
||||
@@ -506,6 +498,30 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
|
||||
let results = (outs TT_IntTensor:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//
|
||||
// ElementwiseInlineAsm Op
|
||||
//
|
||||
def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [Elementwise,
|
||||
SameOperandsAndResultEncoding,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let summary = "inline assembly applying elementwise operation to a group of packed element.";
|
||||
let description = [{
|
||||
This will apply the given in inline assembly to `packed_element` number of
|
||||
elements of the inputs. The elements packed together is unknown and will
|
||||
depend on the backend implementation.
|
||||
}];
|
||||
|
||||
let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic<AnyTypeOf<[TT_Type]>>:$args);
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
$asm_string attr-dict ($args^ `:` type($args))? `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
//
|
||||
@@ -563,6 +579,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
|
||||
|
||||
let results = (outs TT_TensorPtr:$result);
|
||||
|
||||
// TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly
|
||||
// Add additional `[]` to increase readability and split variadic lists
|
||||
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)";
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class TritonTypeDef<string name, string _mnemonic>
|
||||
}
|
||||
|
||||
// Floating-point Type
|
||||
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||
|
||||
@@ -74,7 +74,7 @@ def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
|
||||
// Scalar Pointer Type: `ptr<>`
|
||||
def TT_Ptr : TT_PtrOf<[AnyType]>;
|
||||
|
||||
// Tensor of Pointer Type
|
||||
// Tensor of Pointer Type: `tensor<ptr<>>`
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
|
||||
// Tensor of Pointer Type or Pointer type: `tensor<ptr<>>` or `ptr<>`
|
||||
|
||||
@@ -14,6 +14,8 @@ namespace triton {
|
||||
|
||||
bool isTensorPointerType(Type type);
|
||||
|
||||
bool isTensorOrTensorPointerType(Type type);
|
||||
|
||||
unsigned getPointeeBitWidth(Type type);
|
||||
|
||||
Type getPointeeType(Type type);
|
||||
|
||||
@@ -9,7 +9,6 @@ namespace triton {
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
std::unique_ptr<Pass> createReorderBroadcastPass();
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createRewriteTensorPointerPass(int computeCapability = 80,
|
||||
bool isROCM = false);
|
||||
|
||||
@@ -3,9 +3,13 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu)
|
||||
add_public_tablegen_target(TritonGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
|
||||
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(TritonGPUAttrDefsIncGen)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonGPU depends on Triton
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||
@@ -73,17 +74,41 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned>
|
||||
getShapePerCTA(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
|
||||
|
||||
SmallVector<unsigned> getOrder(Attribute layout);
|
||||
|
||||
CTALayoutAttr getCTALayout(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTASplitNum(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getCTAOrder(Attribute layout);
|
||||
|
||||
/* The difference between ShapePerCTATile and ShapePerCTA:
|
||||
* (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp *
|
||||
* WarpsPerCTA in each dimension and is independent from the tensor shape.
|
||||
* (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension.
|
||||
* (3) In the implementation of emitIndices, ShapePerCTATile will
|
||||
* be replicated or wraped to fit ShapePerCTA.
|
||||
*/
|
||||
SmallVector<unsigned>
|
||||
getShapePerCTATile(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
|
||||
|
||||
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
|
||||
ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> getShapePerCTA(Type type);
|
||||
|
||||
unsigned getNumWarpsPerCTA(Attribute layout);
|
||||
|
||||
unsigned getNumCTAs(Attribute layout);
|
||||
|
||||
bool isaDistributedLayout(Attribute layout);
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
bool isExpensiveCat(CatOp cat, Attribute &targetEncoding);
|
||||
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
@@ -41,6 +41,19 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CTA Layout
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CTALayoutAttr : TritonGPU_Attr<"CTALayout"> {
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$CTAsPerCGA,
|
||||
ArrayRefParameter<"unsigned">:$CTASplitNum,
|
||||
ArrayRefParameter<"unsigned">:$CTAOrder
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -64,19 +77,41 @@ are stored contiguously
|
||||
_ _ _ _ /\_ _ _ _
|
||||
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
|
||||
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case,
|
||||
when the matrix is stored in shared memory, there will be an offset not
|
||||
only in the stride dimension, but also in the leading dimension. For example,
|
||||
a matrix of size 16x128 and data type I8 is stored in the shared memory with
|
||||
64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64,
|
||||
compared to 1*64 when the hasLeadingOffset is false.
|
||||
}];
|
||||
|
||||
// swizzle info: vec, perPhase, maxPhase
|
||||
// order: the fastest-changing axis first
|
||||
let parameters = (
|
||||
ins
|
||||
// swizzle info
|
||||
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
|
||||
"unsigned":$vec,
|
||||
"unsigned":$perPhase,
|
||||
"unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned">:$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"bool":$hasLeadingOffset
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "unsigned":$vec,
|
||||
"unsigned":$perPhase,
|
||||
"unsigned":$maxPhase,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout), [{
|
||||
bool hasLeadingOffset = false; // default value
|
||||
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"unsigned":$typeWidthInBit), [{
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -123,9 +158,10 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
// number of rows per phase
|
||||
|
||||
@@ -134,34 +170,42 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
// ---- begin Volta ----
|
||||
if (mmaEnc.isVolta()) {
|
||||
int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8));
|
||||
int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
bool is_row = order[0] != 0;
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
|
||||
is_row && (shape[order[0]] <= 16);
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) :
|
||||
is_row && (shapePerCTA[order[0]] <= 16);
|
||||
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
|
||||
((is_row && !is_vec4) ? 2 : 1);
|
||||
int rep = 2 * pack_size;
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
int vec = 2 * rep;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// ---- begin Ampere ----
|
||||
if (mmaEnc.isAmpere()) {
|
||||
<<<<<<< HEAD
|
||||
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getKWidth());
|
||||
=======
|
||||
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
|
||||
<<<<<<< HEAD
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
=======
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
@@ -169,12 +213,19 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
|
||||
// ---- begin version 3 ----
|
||||
if (mmaEnc.isHopper()) {
|
||||
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
|
||||
" is Hopper has not been implemented yet");
|
||||
return $_get(context, 1, 1, 1, order, CTALayout, true);
|
||||
}
|
||||
|
||||
// ---- not implemented ----
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
}]>,
|
||||
@@ -182,9 +233,38 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy), [{
|
||||
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
|
||||
return get(context, dotOpEnc, shape, order, bitwidth);
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy), [{
|
||||
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
|
||||
int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1;
|
||||
|
||||
// get proper shared memory swizzling mode from the contiguous dimension
|
||||
// size of the origin blocked layout.
|
||||
auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
|
||||
if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
|
||||
perPhase = 1;
|
||||
maxPhase = 8;
|
||||
} else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
|
||||
perPhase = 2;
|
||||
maxPhase = 4;
|
||||
} else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
|
||||
perPhase = 4;
|
||||
maxPhase = 2;
|
||||
} else {
|
||||
llvm_unreachable("unsupported shared memory layout for MMAv3");
|
||||
}
|
||||
|
||||
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true);
|
||||
}]>
|
||||
];
|
||||
|
||||
@@ -236,7 +316,7 @@ used to promote memory coalescing in LoadInst and StoreInst.
|
||||
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
|
||||
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
|
||||
|
||||
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
|
||||
Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows:
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
@@ -252,82 +332,143 @@ for
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {1, 1}
|
||||
}>
|
||||
|
||||
Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows:
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
for
|
||||
|
||||
#triton_gpu.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {1, 1}
|
||||
}>
|
||||
|
||||
Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and
|
||||
4 CTAs (taking 2x2 for example) as follows:
|
||||
|
||||
CTA [0,0] CTA [0,1]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
|
||||
CTA [1,0] CTA [1,1]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
... ...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
for
|
||||
|
||||
#triton_gpu.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
CTAsPerCGA = {2, 2}
|
||||
}>
|
||||
}];
|
||||
|
||||
|
||||
let builders = [
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// TODO: compiles on MacOS but not linux?
|
||||
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
|
||||
// "ArrayRef<unsigned>":$threadsPerWarp,
|
||||
// "ArrayRef<unsigned>":$warpsPerCTA,
|
||||
// "ArrayRef<unsigned>":$order), [{
|
||||
// int rank = threadsPerWarp.size();
|
||||
// SmallVector<unsigned, 4> sizePerWarp(rank);
|
||||
// SmallVector<unsigned, 4> sizePerCTA(rank);
|
||||
// for (unsigned i = 0; i < rank; i++) {
|
||||
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
|
||||
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
|
||||
// }
|
||||
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
|
||||
// }]>,
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// Default builder takes sizePerThread, order and numWarps, and tries to
|
||||
// pack numWarps*32 threads in the provided order for use in a type
|
||||
// of the given shape.
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$threadsPerWarp), [{
|
||||
int rank = sizePerThread.size();
|
||||
unsigned remainingLanes = threadsPerWarp;
|
||||
unsigned remainingThreads = numWarps*threadsPerWarp;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
SmallVector<unsigned, 4> rankedThreadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
for (int _dim = 0; _dim < rank - 1; ++_dim) {
|
||||
int i = order[_dim];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
|
||||
rankedThreadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / rankedThreadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= rankedThreadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= rankedThreadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
rankedThreadsPerWarp[order[rank-1]] = threadsPerWarp / prevLanes;
|
||||
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, rankedThreadsPerWarp, warpsPerCTA, order);
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$sizePerThread,
|
||||
ArrayRefParameter<"unsigned">:$threadsPerWarp,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
// fastest-changing axis first
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
"order of axes by the rate of changing"
|
||||
>:$order
|
||||
// These attributes can be inferred from the rest
|
||||
// ArrayRefParameter<"unsigned">:$sizePerWarp,
|
||||
// ArrayRefParameter<"unsigned">:$sizePerCTA
|
||||
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
|
||||
"CTALayoutAttr":$CTALayout
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$numThreadsPerWarp,
|
||||
"CTALayoutAttr":$CTALayout), [{
|
||||
unsigned rank = sizePerThread.size();
|
||||
SmallVector<unsigned, 4> threadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
SmallVector<int64_t> shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
|
||||
|
||||
unsigned remainingLanes = numThreadsPerWarp;
|
||||
unsigned remainingThreads = numWarps * numThreadsPerWarp;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
|
||||
// starting from the contiguous dimension
|
||||
for (unsigned d = 0; d < rank - 1; ++d) {
|
||||
unsigned i = order[d];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]);
|
||||
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= threadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= threadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes;
|
||||
warpsPerCTA[order[rank - 1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$numThreadsPerWarp,
|
||||
"unsigned":$numCTAs), [{
|
||||
unsigned rank = sizePerThread.size();
|
||||
SmallVector<unsigned, 4> CTAsPerCGA(rank);
|
||||
SmallVector<unsigned, 4> CTASplitNum(rank);
|
||||
ArrayRef<unsigned> CTAOrder = order;
|
||||
|
||||
unsigned remainingCTAs = numCTAs;
|
||||
|
||||
// starting from the most strided dimension
|
||||
for (int d = rank - 1; d >= 0; --d) {
|
||||
unsigned i = order[d];
|
||||
CTAsPerCGA[i] = std::clamp<unsigned>(remainingCTAs, 1, shape[i] / sizePerThread[i]);
|
||||
CTASplitNum[i] = CTAsPerCGA[i];
|
||||
remainingCTAs /= CTAsPerCGA[i];
|
||||
}
|
||||
|
||||
CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level
|
||||
|
||||
CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
|
||||
return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
@@ -423,13 +564,17 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
ins
|
||||
"unsigned":$versionMajor,
|
||||
"unsigned":$versionMinor,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
ArrayRefParameter<"unsigned">:$instrShape
|
||||
);
|
||||
|
||||
let builders = [
|
||||
// Specially for MMAV1(Volta)
|
||||
AttrBuilder<(ins "int":$versionMajor,
|
||||
"int":$numWarps,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"ArrayRef<unsigned>":$instrShape,
|
||||
"ArrayRef<int64_t>":$shapeC,
|
||||
"bool":$isARow,
|
||||
"bool":$isBRow,
|
||||
@@ -443,7 +588,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
(isAVec4 * (1<<2)) |\
|
||||
(isBVec4 * (1<<3));
|
||||
|
||||
|
||||
// TODO: Share code with
|
||||
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
|
||||
// rep,spw and fpw.
|
||||
@@ -468,11 +612,13 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shapeC[1] / spw[1]);
|
||||
} while (wpt_nm1 != wpt);
|
||||
|
||||
return $_get(context, versionMajor, versionMinor, wpt);
|
||||
return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "int":$versionMajor,
|
||||
"int":$numWarps,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"ArrayRef<unsigned>":$instrShape,
|
||||
"ArrayRef<int64_t>":$shapeA,
|
||||
"ArrayRef<int64_t>":$shapeB,
|
||||
"ArrayRef<int64_t>":$shapeC,
|
||||
@@ -482,15 +628,21 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
return get(context, versionMajor, numWarps, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
|
||||
return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
bool isVolta() const;
|
||||
bool isTuring() const;
|
||||
bool isAmpere() const;
|
||||
bool isHopper() const;
|
||||
|
||||
unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;
|
||||
|
||||
// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
|
||||
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;
|
||||
|
||||
// Number of bits in versionMinor to hold the ID of the MMA encoding instance.
|
||||
// Here 5 bits can hold 32 IDs in a single module.
|
||||
static constexpr int numBitsToHoldMmaV1ID{5};
|
||||
@@ -670,6 +822,4 @@ section 9.7.13.4.1 for more details.
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -16,6 +16,7 @@ def TritonGPU_Dialect : Dialect {
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"mlir::triton::nvgpu::NVGPUDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
@@ -23,14 +24,27 @@ def TritonGPU_Dialect : Dialect {
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
Attribute numWarps = mod->getDiscardableAttr("triton_gpu.num-warps");
|
||||
if(!numWarps)
|
||||
if(!mod->hasAttr("triton_gpu.num-warps"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return numWarps.cast<IntegerAttr>().getInt();
|
||||
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getNumCTAs(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-ctas"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
|
||||
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getComputeCapability(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.compute-capability"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
|
||||
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
|
||||
}
|
||||
void registerTypes();
|
||||
|
||||
static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
|
||||
|
||||
static int getThreadsPerWarp(ModuleOp mod) {
|
||||
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
|
||||
if(!threadsPerWarp) {
|
||||
@@ -38,6 +52,7 @@ def TritonGPU_Dialect : Dialect {
|
||||
}
|
||||
return threadsPerWarp.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
<<<<<<< HEAD
|
||||
static int getSharedSize(ModuleOp mod) {
|
||||
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
|
||||
if(!sharedAttr) {
|
||||
@@ -46,6 +61,8 @@ def TritonGPU_Dialect : Dialect {
|
||||
return sharedAttr.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
=======
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}];
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define TRITONGPU_OPS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
@@ -46,6 +47,20 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncBulkWaitOp : TTG_Op<"async_bulk_wait"> {
|
||||
let summary = "async bulk wait";
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 90;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
|
||||
let summary = "async commit group";
|
||||
|
||||
@@ -58,6 +73,18 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> {
|
||||
let summary = "async bulk commit group";
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 90;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
||||
// This is needed because these ops don't
|
||||
@@ -106,6 +133,98 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
// TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified
|
||||
def TTG_InsertSliceOp : TTG_Op<"insert_slice",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncoding,
|
||||
MemoryEffects<[MemRead, MemWrite]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from src type",
|
||||
"src", "other", "getPointeeType($_self)",
|
||||
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
|
||||
let summary = "insert slice";
|
||||
|
||||
let description = [{
|
||||
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s
|
||||
`$index` argument and `$axis` attribute.
|
||||
|
||||
It returns a copy of `$dst` with the proper slice updated with the value of `$src`.
|
||||
|
||||
When converting from `tt.load` to `triton_gpu.insert_slice`, the `$evict`, `$cache`, and `$isVolatile` fields
|
||||
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
|
||||
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
|
||||
|
||||
The insert_slice operation supports the following arguments:
|
||||
|
||||
* src: the tensor that is inserted.
|
||||
* dst: the tensor into which the `$src` tensor is inserted.
|
||||
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
|
||||
* mask: optional tensor-rank number of boolean masks which specify which
|
||||
elements of the `$src` tensor are inserted into the `$dst` tensor.
|
||||
* other: optional tensor-rank number of other tensors which specify what
|
||||
values are inserted into the `$dst` tensor if the corresponding
|
||||
element of the `$mask` tensor is false.
|
||||
|
||||
ttgpu.load_tile_async depracate
|
||||
triton_gpu.insert_slice might be further lowered into triton_gpu_async for different hardware implementations
|
||||
|
||||
like tt.load, ttgpu.insert_slice/insert_slice_async has two modes up to the type of src
|
||||
mode 1: ptr/src is a tensor of pointers
|
||||
mode 2: ptr/src is a tensor pointer
|
||||
|
||||
Some typical lowering paths are:
|
||||
in case the load is pipelined by the pipeline pass( load is inside kBlock loop, which means "pipeline pass):
|
||||
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1)
|
||||
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1) -(MaterializeLoad)> ttgpu.insert_slice_async(mode 1) + ttgpu.await-> llvm
|
||||
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -(Pipeline)-> ttgpu.insert_slice(mode 2) -(MaterializeLoad)> ttgpu.insert_slice_async_v2(mode 2) + ttgpu.await-> llvm
|
||||
|
||||
otherwise:
|
||||
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1)
|
||||
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 1) + ttgpu.await -> llvm
|
||||
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 2) + ttgpu.await -> llvm
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
|
||||
%2 = triton_gpu.insert_slice %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"Value":$mask, "Value":$other,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
|
||||
DenseSet<unsigned> validLoadBytes;
|
||||
if (computeCapability >= 80) {
|
||||
validLoadBytes = {4, 8, 16};
|
||||
}
|
||||
return validLoadBytes;
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
|
||||
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
|
||||
@@ -173,7 +292,8 @@ def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
|
||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncoding,
|
||||
MemoryEffects<[MemRead]>,
|
||||
// TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
|
||||
MemoryEffects<[MemRead, MemWrite]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
@@ -219,7 +339,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
|
||||
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
26
include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
Normal file
26
include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
Normal file
@@ -0,0 +1,26 @@
|
||||
#ifndef TRITONGPU_TYPES
|
||||
#define TRITONGPU_TYPES
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
|
||||
class TTG_TypeDef<string name, string _mnemonic, list<Trait> traits = []>
|
||||
: TypeDef<TritonGPU_Dialect, name, traits> {
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
def TTG_TokenType : TTG_TypeDef<"Token", "token"> {
|
||||
let parameters = (ins "int32_t":$type);
|
||||
|
||||
let builders = [
|
||||
TypeBuilder<(ins "unsigned":$type), [{
|
||||
return $_get($_ctxt, type);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
10
include/triton/Dialect/TritonGPU/IR/Types.h
Normal file
10
include/triton/Dialect/TritonGPU/IR/Types.h
Normal file
@@ -0,0 +1,10 @@
|
||||
#ifndef TRITONGPU_IR_TYPES_H_
|
||||
#define TRITONGPU_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Types.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
||||
@@ -2,9 +2,14 @@
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
|
||||
int numWarps = 4,
|
||||
int numCTAs = 1,
|
||||
int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
|
||||
|
||||
@@ -27,6 +32,8 @@ std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUOptimizeDotOperandsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUOptimizeEpiloguePass();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -14,13 +14,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numStages", "num-stages",
|
||||
"int32_t", /*default*/"2",
|
||||
"number of pipeline stages">
|
||||
"int32_t", /*default*/"3",
|
||||
"number of pipeline stages">,
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps per block">,
|
||||
Option<"numCTAs", "num-ctas",
|
||||
"int32_t", /*default*/"1",
|
||||
"number of CTAs per CGA">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
@@ -65,6 +75,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
|
||||
let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
@@ -85,6 +96,7 @@ def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir
|
||||
let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
@@ -111,6 +123,20 @@ def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
}
|
||||
|
||||
def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::ModuleOp"> {
|
||||
let summary = "Optimize epilogue: (1) Store accumulators directly without going thorough SMEM in epilogue.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUOptimizeEpiloguePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
}
|
||||
|
||||
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {
|
||||
|
||||
@@ -13,15 +13,17 @@ namespace mlir {
|
||||
|
||||
class TritonGPUTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps,
|
||||
int threadsPerWarp);
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp,
|
||||
int numCTAs);
|
||||
int getNumWarps() const { return numWarps; }
|
||||
int getThreadsPerWarp() const { return threadsPerWarp; }
|
||||
int getNumCTAs() const { return numCTAs; }
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
int numWarps;
|
||||
int threadsPerWarp;
|
||||
int numCTAs;
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
|
||||
@@ -10,33 +10,143 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
LogicalResult fixupLoops(ModuleOp mod);
|
||||
namespace triton {
|
||||
class LoadOp;
|
||||
class StoreOp;
|
||||
class FuncOp;
|
||||
namespace gpu {
|
||||
class SharedEncodingAttr;
|
||||
}
|
||||
} // namespace triton
|
||||
|
||||
// TODO: Interface
|
||||
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
Attribute &ret);
|
||||
SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
|
||||
const ArrayRef<int64_t> &shape,
|
||||
RankedTensorType type);
|
||||
|
||||
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
|
||||
/// Returns true if the Load is for TMA
|
||||
bool isLoadFromTensorPtr(triton::LoadOp op);
|
||||
|
||||
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);
|
||||
/// Returns true if the store is for TMA
|
||||
bool isStoreToTensorPtr(triton::StoreOp op);
|
||||
|
||||
// skipInit is True when we only consider the operands of the initOp but
|
||||
// not the initOp itself.
|
||||
int simulateBackwardRematerialization(
|
||||
Operation *initOp, SetVector<Operation *> &processed,
|
||||
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
|
||||
Attribute targetEncoding);
|
||||
/// Return the first consumer of v
|
||||
Operation *getFirstUser(Value v);
|
||||
|
||||
/// Return the proper SharedEncodingAttr according to shape/order
|
||||
triton::gpu::SharedEncodingAttr getSharedEncoding(RankedTensorType tensorTy);
|
||||
|
||||
/* Dump Triton IR in graphviz dot format.
|
||||
*
|
||||
* You can override `onValue` and `onOperation` in a subclass to mark
|
||||
* specific Values and Operations. The below subclass
|
||||
* GraphLayoutMarker is an example.
|
||||
*
|
||||
* Default NodeInfo for Value nodes:
|
||||
* {{"shape": "box"},
|
||||
* {"style", "filled"},
|
||||
* {"fillcolor", "white"},
|
||||
* {"label", shapeStr}}
|
||||
*
|
||||
* Default NodeInfo for Operation nodes:
|
||||
* {{"shape": "ellipse"},
|
||||
* {"style", "filled"},
|
||||
* {"fillcolor", "white"},
|
||||
* {"label", operationName}}
|
||||
*
|
||||
* If the key "label" is not set by `onValue` or `onOperation`, default labels
|
||||
* will be generated. For Value node, the default label is the shape string and
|
||||
* for Operation node, it is the operation name.
|
||||
*
|
||||
* Reference:
|
||||
* https://graphviz.org/doc/info/shapes.html
|
||||
* https://graphviz.org/doc/info/colors.html
|
||||
*
|
||||
* Usage:
|
||||
* C++: GraphDumper().dumpToFile(func, "func.dot");
|
||||
* Shell: dot -Tjpg func.dot -o func.jpg
|
||||
*/
|
||||
class GraphDumper {
|
||||
public:
|
||||
using NodeInfo = std::map<std::string, std::string>;
|
||||
|
||||
// Override this function to mark specific Values
|
||||
virtual NodeInfo onValue(Value value) const;
|
||||
// Override this function to mark specific Operations
|
||||
virtual NodeInfo onOperation(Operation *op) const;
|
||||
|
||||
std::string dump(triton::FuncOp func) const;
|
||||
void dumpToFile(triton::FuncOp func, const std::string &filename) const;
|
||||
|
||||
protected:
|
||||
std::string getShapeStr(const Type &type) const;
|
||||
|
||||
std::string getUniqueId(Value value) const;
|
||||
std::string getUniqueId(Operation *op) const;
|
||||
|
||||
std::string emitNode(const std::string &id, const NodeInfo style) const;
|
||||
std::string emitEdge(const std::string &srcId,
|
||||
const std::string &destId) const;
|
||||
|
||||
std::string emitValueNode(Value value) const;
|
||||
std::string emitOperationNode(Operation *op) const;
|
||||
};
|
||||
|
||||
/* A subclass of GraphDumper that marks different layout kinds in different
|
||||
* colors.*/
|
||||
class GraphLayoutMarker : public GraphDumper {
|
||||
public:
|
||||
NodeInfo onValue(Value value) const override;
|
||||
|
||||
protected:
|
||||
std::string getColor(const Type &type) const;
|
||||
};
|
||||
|
||||
// Infers the encoding of the result of op given the source encoding.
|
||||
std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding);
|
||||
|
||||
// Infers the encoding of the source of op given the result encoding.
|
||||
std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding);
|
||||
|
||||
bool isExpensiveLoadOrStore(Operation *op);
|
||||
|
||||
bool canFoldIntoConversion(Operation *op, Attribute targetEncoding);
|
||||
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping);
|
||||
|
||||
void rematerializeConversionChain(
|
||||
const llvm::MapVector<Value, Attribute> &toConvert,
|
||||
mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
|
||||
IRMapping &mapping);
|
||||
// Get backward slice of tensor values starting from the root node along with
|
||||
// encoding propagation.
|
||||
LogicalResult getConvertBackwardSlice(
|
||||
Value root, SetVector<Value> &slice, Attribute rootEncoding,
|
||||
DenseMap<Value, Attribute> &layout,
|
||||
std::function<bool(Operation *)> stopPropagation = nullptr);
|
||||
|
||||
LogicalResult canMoveOutOfLoop(BlockArgument arg,
|
||||
SmallVector<Operation *> &cvts);
|
||||
// Populate pattern to remove dead cycles in ForOp.
|
||||
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order);
|
||||
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, unsigned linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
|
||||
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
// Returns null if the op is not inside a agent region (warp specialization
|
||||
// mode). Note that there should be at most one agent id attached to the
|
||||
// operation.
|
||||
std::optional<int> getWSAgentId(Operation *op);
|
||||
std::optional<int> getWSRoleId(Operation *op);
|
||||
void setRoleId(Operation *op, int roleId);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
2
include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt
Normal file
2
include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
15
include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt
Normal file
15
include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu)
|
||||
add_public_tablegen_target(TritonNvidiaGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td)
|
||||
mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen)
|
||||
46
include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
Normal file
46
include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
Normal file
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonNvidiaGPU depends on Triton
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Traits.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
|
||||
53
include/triton/Dialect/TritonNvidiaGPU/IR/Traits.h
Normal file
53
include/triton/Dialect/TritonNvidiaGPU/IR/Traits.h
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_NVIDIA_GPU_IR_TRAITS_H_
|
||||
#define TRITON_NVIDIA_GPU_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifySource1IsSharedEncoding(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <typename ConcreteType>
|
||||
class Source1IsSharedEncoding
|
||||
: public TraitBase<ConcreteType, Source1IsSharedEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySource1IsSharedEncoding(op);
|
||||
}
|
||||
};
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_ATTRDEFS
|
||||
#define TRITONNVIDIAGPU_ATTRDEFS
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_DIALECT
|
||||
#define TRITONNVIDIAGPU_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TritonNvidiaGPU_Dialect : Dialect {
|
||||
let name = "triton_nvidia_gpu";
|
||||
|
||||
let cppNamespace = "::mlir::triton::nvidia_gpu";
|
||||
|
||||
let hasOperationAttrVerify = 1;
|
||||
|
||||
let description = [{
|
||||
Triton Nvidia GPU Dialect.
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvgpu::NVGPUDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-warps"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getNumCTAs(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-ctas"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
|
||||
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
static int getComputeCapability(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.compute-capability"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
|
||||
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
|
||||
}
|
||||
void registerTypes();
|
||||
|
||||
// Warp specialization related:
|
||||
static std::string getWSSupportedAttrName() { return "triton_gpu.enable-warp-specialization"; }
|
||||
static int getWSSupportedAttr(ModuleOp mod) {
|
||||
auto name = getWSSupportedAttrName();
|
||||
if (!mod->hasAttr(name)) return 0;
|
||||
return mod->getAttrOfType<IntegerAttr>(name).getInt();
|
||||
}
|
||||
}];
|
||||
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
|
||||
|
||||
#endif
|
||||
385
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Normal file
385
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Normal file
@@ -0,0 +1,385 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_OPS
|
||||
#define TRITONNVIDIAGPU_OPS
|
||||
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
include "mlir/Interfaces/DestinationStyleOpInterface.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
|
||||
def Source1IsSharedEncoding: NativeOpTrait<"Source1IsSharedEncoding">;
|
||||
|
||||
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
|
||||
|
||||
class TTNG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonNvidiaGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
// --------------------------------------------------------------------------------------------------
|
||||
// MBarrier related Ops:
|
||||
// 1, These mbarrier commands are currently not needed, and not taken into consideration:
|
||||
// (1), mbarrier.expect_tx
|
||||
// (2), mbarrier.arrive_drop
|
||||
// (3), mbarrier.complete_tx
|
||||
// (4), mbarrier.inval
|
||||
//
|
||||
// 2, The mbarriers is supported to be created in vector, and accessed in seperate via tensor.extract.
|
||||
// The mbarriers created in vector will have counters initialized in the same configuration. A
|
||||
// typical example to demonstrate this:
|
||||
//
|
||||
// %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4x!tt.ptr<i64>>
|
||||
// scf.for %iv = %lb to %ub step %step iter_args() -> () {
|
||||
// %buffer_id = arith.remi %iv, %c4 : i32
|
||||
// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
|
||||
// triton_nvidia_gpu.mbarrier_arrive %2 {expectTx = 2048} : !tt.ptr<i64> -> ()
|
||||
// }
|
||||
// ...
|
||||
// scf.for %iv = %lb to %ub step %step iter_args() -> () {
|
||||
// %buffer_id = arith.remi %iv, %c4 : i32
|
||||
// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
|
||||
// triton_nvidia_gpu.mbarrier_wait %2, %c0 : !tt.ptr<i64>, i1 -> ()
|
||||
// }
|
||||
|
||||
def TTNG_AllocMBarrierOp : TTNG_Op<"alloc_mbarrier", [MemoryEffects<[MemAlloc]>]> {
|
||||
let summary = "allocate a vector of mbarriers";
|
||||
|
||||
let description = [{
|
||||
Allocate and initialize a vector of mbarriers. The size of the vector is implied in the returned type.
|
||||
Each mbarrier is initialized as:
|
||||
1, the current phase initialized to 0.
|
||||
2, the expected arrival count initialized to 'count'.
|
||||
3, the pending arrival count initialized to 'count'.
|
||||
4, the tx-count initialized to 0.
|
||||
|
||||
Example:
|
||||
|
||||
case a. when created in vector:
|
||||
%1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4xi64>
|
||||
|
||||
case b. when created in scalar:
|
||||
%1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : !tt.ptr<i64>
|
||||
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{attr-dict `:` type($result)}];
|
||||
|
||||
let arguments = (ins I32Attr:$count);
|
||||
|
||||
let results = (outs AnyTypeOf<[TT_Ptr, I64Tensor]>:$result);
|
||||
}
|
||||
|
||||
def TTNG_ExtractMBarrierOp : TTNG_Op<"extract_mbarrier", [Pure]> {
|
||||
let summary = "extract a mbarrier from a vector of mbarriers";
|
||||
|
||||
let description = [{
|
||||
Extract a mbarrier from a vector of mbarriers
|
||||
|
||||
Example:
|
||||
|
||||
%1 = triton_nvidia_gpu.extract_mbarrier %mbarriers[%idx] : tensor<4xi64>, index -> !tt.ptr<i64>
|
||||
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor) `,` type($index) `->` type($result)";
|
||||
|
||||
let arguments = (ins I64Tensor:$tensor, I32:$index);
|
||||
|
||||
let results = (outs TT_Ptr:$result);
|
||||
}
|
||||
|
||||
def TTNG_MBarrierWaitOp : TTNG_Op<"mbarrier_wait", [MemoryEffects<[MemRead, MemWrite]>]> {
|
||||
let summary = "mbarrier wait";
|
||||
|
||||
let description = [{
|
||||
This operation defining the waiting action for a mbarrier.
|
||||
The subsequent operations should not execute until this operation completes waiting.
|
||||
|
||||
Example:
|
||||
|
||||
triton_nvidia_gpu.mbarrier_wait %0, %1 : !tt.ptr<i64>
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_Ptr:$mbarrier, I1: $phase);
|
||||
|
||||
let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type($mbarrier)";
|
||||
}
|
||||
|
||||
def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments,
|
||||
MemoryEffects<[MemWrite]>]> {
|
||||
let summary = "mbarrier arrive";
|
||||
|
||||
let description = [{
|
||||
This operation defining the arriving action for a mbarrier.
|
||||
txCount:
|
||||
An optional attribute that set tx-count. This Op will be lowered into
|
||||
mbarrier.arrive.expect_tx if the optional attribute exist.
|
||||
trackAsyncOp:
|
||||
If true, this op will be lowered into cp.async.mbarrier.arrive.noinc.
|
||||
pred:
|
||||
Only perform arrive action when pred is true.
|
||||
remoteCtaId:
|
||||
if set, perform an remote arrive action.
|
||||
|
||||
Example:
|
||||
|
||||
triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr<i64>
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_Ptr:$mbarrier,
|
||||
Optional<I1>:$pred,
|
||||
Optional<I32>:$remoteCtaId,
|
||||
I1Attr: $trackAsyncOp,
|
||||
DefaultValuedAttr<I32Attr, "0">: $txCount
|
||||
);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
|
||||
let arguments = (ins BoolAttr:$bCluster);
|
||||
|
||||
let summary = "fence proxy async";
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 90;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// TODO[goostavz]: ThreadId & ClusterCTAId should not be exposed to
|
||||
// ttgpu level. Remove them when async dialect is ready.
|
||||
def TTNG_GetThreadIdOp : TTNG_Op<"get_thread_id", [Pure]> {
|
||||
let description = [{
|
||||
Returns the one dimensional threadId.
|
||||
}];
|
||||
|
||||
let results = (outs I32:$result);
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TTNG_GetClusterCTAIdOp : TTNG_Op<"get_cluster_cta_id", [Pure]> {
|
||||
let description = [{
|
||||
Returns the one dimensional cluster_cta_id.
|
||||
}];
|
||||
|
||||
let results = (outs I32:$result);
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> {
|
||||
let summary = "named barrier arrive";
|
||||
|
||||
let arguments = (ins I32:$bar, I32: $numThreads);
|
||||
|
||||
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> {
|
||||
let summary = "named barrier wait";
|
||||
|
||||
let arguments = (ins I32:$bar, I32: $numThreads);
|
||||
|
||||
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_InsertSliceAsyncV2Op : TTNG_Op<"insert_slice_async_v2",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncoding,
|
||||
// TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
|
||||
MemoryEffects<[MemRead, MemWrite]>]> {
|
||||
|
||||
let arguments = (ins AnyTypeOf<[TT_Ptr, TT_PtrTensor]>:$src, TT_Tensor:$dst,
|
||||
I32:$index, TT_Ptr:$mbar,
|
||||
Optional<AnyTypeOf<[I1Tensor, I1]>>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
||||
}
|
||||
|
||||
// TODO: the abstraction of barriers in ttgpu level is pending, will revisit later
|
||||
// def TTNG_AwaitOp : TTNG_Op<"await", []> {
|
||||
// let arguments = (ins TTNG_TokenType:$token);
|
||||
// let assemblyFormat = "$token attr-dict `:` type($token)";
|
||||
// }
|
||||
|
||||
def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> {
|
||||
let arguments = (ins I1Attr:$relaxed);
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
//
|
||||
// DotAsync Op
|
||||
//
|
||||
def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"result's type matches accumulator's type",
|
||||
"d", "c", "$_self">]> {
|
||||
let summary = "dot async";
|
||||
|
||||
let description = [{
|
||||
$d = matrix_multiply($a, $b) + $c
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||
}
|
||||
|
||||
def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> {
|
||||
let summary = "dot wait";
|
||||
|
||||
let description = [{
|
||||
This operation defining the waiting action for a async dot, MMAv3 .e.g.
|
||||
The subsequent operations should not execute until this operation completes waiting.
|
||||
}];
|
||||
|
||||
let arguments = (ins I32Attr:$pendings);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def TTNG_StoreAsyncOp : TTNG_Op<"store_async",
|
||||
[MemoryEffects<[MemWrite]>]> {
|
||||
let summary = "store asynchronous by a tensor pointer";
|
||||
let arguments = (ins TT_TensorPtr:$dst, TT_Tensor:$src,
|
||||
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache);
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_GetAgentIdOp : TTNG_Op<"get_agent_id", [Pure]> {
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let builders = [OpBuilder<(ins)>];
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Token
|
||||
//
|
||||
|
||||
def TTNG_CreateTokenOp : TTNG_Op<"create_token"> {
|
||||
let results = (outs TensorOf<[TTNG_TokenType]>:$result);
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let builders = [OpBuilder<(ins "uint32_t":$num)>];
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> {
|
||||
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
|
||||
|
||||
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> {
|
||||
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
|
||||
|
||||
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> {
|
||||
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
|
||||
|
||||
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> {
|
||||
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
|
||||
|
||||
let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
//
|
||||
// Mutex
|
||||
//
|
||||
|
||||
def TTNG_GetMutexRoleIdOp : TTNG_Op<"get_mutex_role_id"> {
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let builders = [OpBuilder<(ins "uint32_t":$num)>];
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TTNG_CreateMutexOp : TTNG_Op<"create_mutex"> {
|
||||
let results = (outs TTNG_MutexType:$result);
|
||||
|
||||
let builders = [OpBuilder<(ins)>];
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TTNG_LockOp : TTNG_Op<"lock"> {
|
||||
let arguments = (ins TTNG_MutexType:$mutex);
|
||||
|
||||
let assemblyFormat = "$mutex attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_UnlockOp : TTNG_Op<"unlock"> {
|
||||
let arguments = (ins TTNG_MutexType:$mutex);
|
||||
|
||||
let assemblyFormat = "$mutex attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> {
|
||||
let summary = "register allocation";
|
||||
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
|
||||
let assemblyFormat = "$regCount attr-dict";
|
||||
}
|
||||
|
||||
def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
|
||||
let summary = "register deallocation";
|
||||
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
|
||||
let assemblyFormat = "$regCount attr-dict";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_TYPES
|
||||
#define TRITONNVIDIAGPU_TYPES
|
||||
|
||||
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
|
||||
class TTNG_TypeDef<string name, string _mnemonic>
|
||||
: TypeDef<TritonNvidiaGPU_Dialect, name> {
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
def TTNG_TokenType : TTNG_TypeDef<"Token", "token">;
|
||||
|
||||
def TTNG_MutexType : TTNG_TypeDef<"Mutex", "mutex">;
|
||||
|
||||
#endif
|
||||
33
include/triton/Dialect/TritonNvidiaGPU/IR/Types.h
Normal file
33
include/triton/Dialect/TritonNvidiaGPU/IR/Types.h
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_IR_TYPES_H_
|
||||
#define TRITONNVIDIAGPU_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
||||
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU)
|
||||
add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen)
|
||||
83
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h
Normal file
83
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace nvidia_gpu {
|
||||
|
||||
// Used by Triton runtime
|
||||
struct ClusterInfo {
|
||||
ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {}
|
||||
int clusterDimX;
|
||||
int clusterDimY;
|
||||
int clusterDimZ;
|
||||
};
|
||||
|
||||
} // namespace nvidia_gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
namespace mlir {
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonNvidiaGPUMaterializeLoadStorePass(int numWarps = 4,
|
||||
int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
|
||||
mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonNvidiaGPUWSFeasibilityCheckingPass(int computeCapability = 90);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonNvidiaGPUWSDecomposingPass(int computeCapability = 90);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonNvidiaGPUWSPipelinePass(int numStages = 3, int numWarps = 4,
|
||||
int computeCapability = 90);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonNvidiaGPUWSMutexPass(int computeCapability = 90);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonNvidiaGPUWSMaterializationPass(int computeCapability = 90);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonGPURewriteTensorPointerPass(int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonNvidiaGPUWSFixupMissingAttrs();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
|
||||
246
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td
Normal file
246
include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td
Normal file
@@ -0,0 +1,246 @@
|
||||
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
// a copy of this software and associated documentation files
|
||||
// (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge,
|
||||
// publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
// and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be
|
||||
// included in all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
#ifndef TRITONNVIDIAGPU_PASSES
|
||||
#define TRITONNVIDIAGPU_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def MaterializeLoadStore : Pass<"triton-nvidia-gpu-materialize-load-store", "mlir::ModuleOp"> {
|
||||
let summary = "materialize load & store";
|
||||
|
||||
let description = [{
|
||||
This pass works after pipeline pass, converting the remaining tt.LoadOp taking
|
||||
ptr<tensor> as input into ttg.InsertSliceAsyncOp and emit proper barriers
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUMaterializeLoadStorePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps per block">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> {
|
||||
let summary = "plan CTA";
|
||||
|
||||
let description = [{
|
||||
Plan CTAs in CGA
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()";
|
||||
|
||||
let dependentDialects = [
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUWSFeasibilityChecking : Pass<"triton-nvidia-gpu-ws-feasibility-checking", "mlir::ModuleOp"> {
|
||||
let summary = "Attach attr named TritonNvidiaGPUDialect::getWSSupportedAttrName() if auto WS supported";
|
||||
|
||||
let description = [{
|
||||
Since not every legal triton kernels can be auto WS, this pass does some (conservative) check
|
||||
and attaches an attribute named TritonNvidiaGPUDialect::getWSSupportedAttrName() on
|
||||
the input module op if the kernel is supported.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass()";
|
||||
|
||||
let dependentDialects = [
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
|
||||
];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"90",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUWSDecomposing : Pass<"triton-nvidia-gpu-ws-decomposing", "mlir::ModuleOp"> {
|
||||
let summary = "Clustering on the ops according to their performance hotspots";
|
||||
|
||||
let description = [{
|
||||
Based on compute capability and heuristics,
|
||||
this pass will identify some operations to be executed in different agents,
|
||||
by marking them with async 'label'. E.g.,
|
||||
input:
|
||||
%1 = tt,load %0 ...
|
||||
%4 = tt.dot %1, %2, %3 ...
|
||||
output:
|
||||
%1 = tt,load %0 {async_agent = 0} ...
|
||||
%4 = tt.dot %1, %2, %3 {async_agent = 1} : ...
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUWSDecomposingPass()";
|
||||
|
||||
let dependentDialects = [
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
|
||||
];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUWSPipeline : Pass<"triton-nvidia-gpu-ws-pipeline", "mlir::ModuleOp"> {
|
||||
let summary = "Warp specialization pipeline";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUWSPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numStages", "num-stages",
|
||||
"int32_t", /*default*/"3",
|
||||
"number of pipeline stages">,
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"12",
|
||||
"number of warps per block">,
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"90",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUWSMutex : Pass<"triton-nvidia-gpu-ws-mutex", "mlir::ModuleOp"> {
|
||||
let summary = "Warp specialization mutex syncronization";
|
||||
|
||||
let description = [{
|
||||
create mutex syncronization for persistent kernel. (as "2 Math WG" persistent kernel in cutlass)
|
||||
For example, the agent containing dot and store will be divided into two sub-agent,
|
||||
which execute dot and store alternately. i.e.:
|
||||
sub-agent-0: dot | store | dot | ... | store
|
||||
sub-agent-1: | dot | store | ... | dot | store
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUWSMutexPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUWSMaterialization : Pass<"triton-nvidia-gpu-ws-materialization", "mlir::ModuleOp"> {
|
||||
let summary = "Warp specialization materialization";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUWSMaterializationPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"90",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> {
|
||||
let summary = "Insert fences across generic and async proxy";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()";
|
||||
|
||||
let dependentDialects = [
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
|
||||
];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"90",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPURewriteTensorPointer : Pass</*cli-arg*/"tritongpu-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
|
||||
let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
|
||||
let description = [{
|
||||
This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy
|
||||
semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute
|
||||
the pointer/mask/other for each load/store.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPURewriteTensorPointerPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUWSFixupMissingAttrs : Pass<"triton-nvidia-gpu-ws-fixup-missing-attrs", "mlir::ModuleOp"> {
|
||||
let summary = "Fixup missing WS related attributes";
|
||||
|
||||
let description = [{
|
||||
WS related attributes are attached to some key operations and are used when lowering to llvm.
|
||||
However these attributes maybe be dropped in the following IR transform. This pass tries to
|
||||
fixup the missing attributes.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUWSFixupMissingAttrs()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
95
include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h
Normal file
95
include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h
Normal file
@@ -0,0 +1,95 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
|
||||
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// 0 is reserved for default sync.
|
||||
// TODO: comprehensive mechanism to globally manage namedbarrier.
|
||||
static int const nameBarrierIdBegin = 1;
|
||||
static int nameBarrierIdEnd = 16;
|
||||
|
||||
/// Helper functions for async agent
|
||||
typedef int AgentId;
|
||||
SmallVector<AgentId> getAgentIds(Operation *op);
|
||||
bool hasAgentId(Operation *op, AgentId agentId);
|
||||
void setAgentIds(Operation *op, ArrayRef<AgentId> agentIds);
|
||||
SmallVector<AgentId> collectAgentIds(Operation *op);
|
||||
void addAgentIds(Operation *op, ArrayRef<int> agents);
|
||||
SmallVector<int> getMutexBarIds(Operation *op);
|
||||
SmallVector<int> getMutexNumThreads(Operation *op);
|
||||
|
||||
class OpBuilderWithAgentIds : public OpBuilder {
|
||||
public:
|
||||
OpBuilderWithAgentIds(MLIRContext *context) : OpBuilder(context) {}
|
||||
|
||||
void setAgentIdsFromArray(ArrayRef<AgentId> newAgentIds) {
|
||||
agentIds = SmallVector<AgentId>(newAgentIds.begin(), newAgentIds.end());
|
||||
}
|
||||
|
||||
void setAgentIdsFromOp(Operation *op) {
|
||||
setAgentIdsFromArray(getAgentIds(op));
|
||||
}
|
||||
|
||||
void setAgentIdsFromValueUsers(Value value) {
|
||||
SetVector<AgentId> agentIdSet;
|
||||
for (Operation *user : value.getUsers())
|
||||
for (AgentId agentId : getAgentIds(user))
|
||||
agentIdSet.insert(agentId);
|
||||
setAgentIdsFromArray(agentIdSet.getArrayRef());
|
||||
}
|
||||
|
||||
template <typename OpTy, typename... Args>
|
||||
OpTy createWithAgentIds(Args &&...args) {
|
||||
OpTy op = create<OpTy>(std::forward<Args>(args)...);
|
||||
if (!agentIds.empty())
|
||||
setAgentIds(op, agentIds);
|
||||
return op;
|
||||
}
|
||||
|
||||
private:
|
||||
SmallVector<AgentId> agentIds;
|
||||
};
|
||||
|
||||
/// Constant agent ids
|
||||
constexpr AgentId kLoadAgentId = 0;
|
||||
constexpr AgentId kDotAgentId = 1;
|
||||
|
||||
bool isWSCandidateLoad(Operation *op);
|
||||
bool isWSSupported(ModuleOp m, int computeCapability);
|
||||
|
||||
LogicalResult getDependentValues(Value val, DenseSet<Value> &depSet,
|
||||
const DenseSet<Value> &stopSet = {});
|
||||
LogicalResult getDependentValues(Operation *op, DenseSet<Value> &depSet,
|
||||
const DenseSet<Value> &stopSet = {});
|
||||
DenseSet<Operation *> getDependentOps(DenseSet<Value> &depSet);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
|
||||
19
include/triton/Target/AMDGCN/AMDGCNTranslation.h
Normal file
19
include/triton/Target/AMDGCN/AMDGCNTranslation.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
|
||||
#define TRITON_TARGET_AMDGCNTRANSLATION_H
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
} // namespace llvm
|
||||
|
||||
namespace triton {
|
||||
|
||||
// Translate LLVM IR to AMDGCN code.
|
||||
std::tuple<std::string, std::string>
|
||||
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
@@ -1,5 +1,7 @@
|
||||
#ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
|
||||
#define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@@ -26,12 +28,16 @@ void addExternalLibs(mlir::ModuleOp &module,
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability,
|
||||
bool isROCM);
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
Target target);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
bool isROCM);
|
||||
Target target);
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
llvm::StringRef path, Target target);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
107
include/triton/Target/PTX/TmaMetadata.h
Normal file
107
include/triton/Target/PTX/TmaMetadata.h
Normal file
@@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TRITON_TARGET_PTX_TMAMETADATA_H
|
||||
#define TRITON_TARGET_PTX_TMAMETADATA_H
|
||||
|
||||
#include "python/triton/third_party/cuda/include/cuda.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace gpu {
|
||||
|
||||
struct TMAInfo {
|
||||
// --------------------------------------------
|
||||
// informations to be filled into CUtensorMaps
|
||||
int tensorDataType;
|
||||
|
||||
uint32_t tensorRank;
|
||||
|
||||
// the argument indices for the runtime to get globalAddresses
|
||||
size_t globalAddressArgIdx;
|
||||
|
||||
// the argument indices for the runtime to get globalDims, -1 stands for this
|
||||
// dim is padded
|
||||
std::vector<int32_t> globalDimsArgIdx;
|
||||
|
||||
// the argument indices for the runtime to get globalStrides, -1 stands for
|
||||
// this dim is padded the runtime need to map the value to internal format
|
||||
std::vector<int32_t> globalStridesArgIdx;
|
||||
|
||||
std::vector<uint32_t> boxDims;
|
||||
|
||||
std::vector<uint32_t> elementStrides;
|
||||
|
||||
int interleave;
|
||||
|
||||
int swizzle;
|
||||
|
||||
int l2Promotion;
|
||||
|
||||
int oobFill;
|
||||
|
||||
// --------------------------------------------
|
||||
// the argument indices for the runtime to send the address of tma_desc to the
|
||||
// binary
|
||||
int TMADescArgIdx;
|
||||
|
||||
template <typename T>
|
||||
void dump_vec(const std::vector<T> &vec, llvm::StringRef info) const {
|
||||
llvm::errs() << info << ": ";
|
||||
for (const T &e : vec)
|
||||
llvm::errs() << e << ",";
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void dump() const {
|
||||
llvm::errs() << "TMA Info: ----------"
|
||||
<< "\n";
|
||||
llvm::errs() << "-- tensorDataType: " << tensorDataType
|
||||
<< ", tensorRank: " << tensorRank << "\n";
|
||||
llvm::errs() << "-- globalAddressArgIdx: " << globalAddressArgIdx << "\n";
|
||||
llvm::errs() << "-- TMADescArgIdx: " << TMADescArgIdx << "\n";
|
||||
dump_vec<int32_t>(globalDimsArgIdx, "-- globalDimsArgIdx");
|
||||
dump_vec<int32_t>(globalStridesArgIdx, "-- globalStridesArgIdx");
|
||||
dump_vec<uint32_t>(boxDims, "-- boxDims");
|
||||
dump_vec<uint32_t>(elementStrides, "-- elementStrides");
|
||||
llvm::errs() << "-- interleave: " << interleave << "\n";
|
||||
llvm::errs() << "-- swizzle: " << swizzle << "\n";
|
||||
llvm::errs() << "-- l2Promotion: " << l2Promotion << "\n";
|
||||
llvm::errs() << "-- oobFill: " << oobFill << "\n";
|
||||
};
|
||||
};
|
||||
|
||||
using TMAMetadataTy = std::vector<TMAInfo>;
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_TARGET_PTX_TMAMETADATA_H
|
||||
@@ -24,10 +24,16 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
namespace triton {
|
||||
|
||||
const std::set<std::string> ENV_VARS = {
|
||||
"ENABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION",
|
||||
"ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP",
|
||||
"AMDGCN_ENABLE_DUMP"};
|
||||
|
||||
namespace tools {
|
||||
|
||||
inline std::string getenv(const char *name) {
|
||||
@@ -39,6 +45,9 @@ inline std::string getenv(const char *name) {
|
||||
}
|
||||
|
||||
inline bool getBoolEnv(const std::string &env) {
|
||||
std::string msg = "Environment variable " + env + " is not recognized";
|
||||
assert(triton::ENV_VARS.find(env.c_str()) != triton::ENV_VARS.end() &&
|
||||
msg.c_str());
|
||||
const char *s = std::getenv(env.c_str());
|
||||
std::string str(s ? s : "");
|
||||
std::transform(str.begin(), str.end(), str.begin(),
|
||||
|
||||
Reference in New Issue
Block a user