upgrade llvm to b1115f8c (NFC) (#2403)

Co-authored-by: Thomas Raoux <thomas.raoux@openai.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
Mehdi Amini
2023-10-16 16:38:49 -07:00
committed by GitHub
parent 87a223d76f
commit 721897fcc4
36 changed files with 128 additions and 109 deletions

View File

@@ -68,10 +68,10 @@ arbitrary LLVM version.
1. Find the version of LLVM that Triton builds against. Check `python/setup.py`
for a line like
version = "llvm-17.0.0-c5dede880d17"
version = "llvmorg-18-init-7000-g76ce4736721a"
This means that the version of Triton you have builds against
[LLVM](https://github.com/llvm/llvm-project) c5dede880d17.
[LLVM](https://github.com/llvm/llvm-project) 76ce4736721a.
2. `git checkout` LLVM at this revision. Optionally, make additional
modifications to LLVM.

View File

@@ -63,11 +63,12 @@ private:
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
class SharedMemoryAliasAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
: public dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>> {
public:
using dataflow::SparseDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseForwardDataFlowAnalysis;
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::getLatticeElement;
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.

View File

@@ -271,8 +271,8 @@ private:
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
};
class AxisInfoAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>> {
private:
AxisInfoVisitorList visitors;
@@ -284,7 +284,7 @@ private:
public:
AxisInfoAnalysis(DataFlowSolver &solver);
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>>::getLatticeElement;
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;

View File

@@ -9,9 +9,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"

View File

@@ -6,11 +6,11 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -668,6 +668,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
// Required by CallOpInterface.
MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}
}];
let assemblyFormat = [{

View File

@@ -885,7 +885,8 @@ public:
//===----------------------------------------------------------------------===//
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
: dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
solver) {
// UnrealizedConversionCast:
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast

View File

@@ -640,7 +640,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = backwardFilter;
getBackwardSlice(currentOp, &backwardSlice, opt);
slice.insert(backwardSlice.begin(), backwardSlice.end());
// Compute and insert the forwardSlice starting from currentOp.

View File

@@ -14,7 +14,7 @@ add_mlir_conversion_library(NVGPUToLLVM
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms

View File

@@ -43,7 +43,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
ASMBuilder
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms

View File

@@ -1314,18 +1314,18 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \

View File

@@ -145,7 +145,8 @@ protected:
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, LLVM::CConv::C, attributes);
/*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/SymbolRefAttr{},
attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,

View File

@@ -89,7 +89,7 @@
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
} while (0)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
// Types

View File

@@ -115,8 +115,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
// Floating point
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
// MaxMin
GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
// Floating point
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,

View File

@@ -1,9 +1,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

View File

@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
TritonGPUAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRGPUOps
MLIRGPUDialect
TritonIR
)

View File

@@ -1360,7 +1360,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
// Deduce operandSegmentSizes from the number of the operands.
auto operandSegmentSizesAttrName =
OpT::getOperandSegmentSizesAttrName(result.name);
result.addAttribute(
@@ -1373,7 +1373,7 @@ template <class OpT>
void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
printer << " ";
printer << insertSliceOp.getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
// "operandSegmentSizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(
insertSliceOp->getAttrs(),
{insertSliceOp.getOperandSegmentSizesAttrName()});

View File

@@ -117,7 +117,10 @@ class BlockedToMMA : public mlir::RewritePattern {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
mlir::getBackwardSlice(x, &slice, bwdFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = bwdFilter;
getBackwardSlice(x, &slice, opt);
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
@@ -213,8 +216,11 @@ public:
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
getBackwardSlice(a, &aBwdSlices, {isCvt});
getBackwardSlice(b, &bBwdSlices, {isCvt});
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = isCvt;
getBackwardSlice(a, &aBwdSlices, opt);
getBackwardSlice(b, &bBwdSlices, opt);
// get the source of the first conversion found in slices
auto getCvtArgOrder = [](Operation *op) {
return cast<ConvertLayoutOp>(op)

View File

@@ -98,7 +98,9 @@ public:
// and all operations between the load and the conversion
// should be layout preserving
SetVector<Operation *> slice;
getBackwardSlice(op, &slice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op, &slice, opt);
int loadIdx = -1;
bool checkOp = false;
for (int i = 0; i < slice.size(); i++) {

View File

@@ -91,7 +91,7 @@ private:
// support ForOp only
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
// prologue
auto iterOperands = forOp.getIterOperands();
auto iterOperands = forOp.getInitArgs();
if (argNum == 0)
return false;
if (dependOnSharedEncOperand(iterOperands[argNum - 1]))

View File

@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
arith::UIToFPOp, arith::XOrIOp>(op))
return true;
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,

View File

@@ -220,7 +220,9 @@ public:
SetVector<Operation *> backwardSlice;
mod.walk([&](triton::MakeTensorPtrOp op) -> void {
assert(isa<triton::FuncOp>(op->getParentOp()));
getBackwardSlice(op.getOperation(), &backwardSlice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op.getOperation(), &backwardSlice, opt);
op->removeAttr("async_agent");
});
for (auto op : backwardSlice) {

View File

@@ -259,7 +259,7 @@ scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
// 3. create newLoopArgs
SmallVector<Value> newLoopArgs;
for (auto operand : forOp.getIterOperands())
for (auto operand : forOp.getInitArgs())
newLoopArgs.push_back(operand);
builder.setInsertionPoint(forOp);

View File

@@ -93,7 +93,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
opt.TrapUnreachable = true;
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt, llvm::CodeGenOpt::Aggressive)};
std::nullopt, llvm::CodeGenOptLevel::Aggressive)};
// set data layout
if (layout.empty())
module.setDataLayout(machine->createDataLayout());
@@ -109,7 +109,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
llvm::legacy::PassManager pass;
// emit
machine->addPassesToEmitFile(pass, pstream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
llvm::CodeGenFileType::AssemblyFile);
pass.run(module);
}
// post-process

View File

@@ -73,23 +73,20 @@ def get_llvm_package_info():
if arch == 'aarch64':
arch = 'arm64'
if system == "Darwin":
system_suffix = "apple-darwin"
arch = platform.machine()
system_suffix = f"macos-{arch}"
elif system == "Linux":
# TODO: arm64
vglibc = tuple(map(int, platform.libc_ver()[1].split('.')))
vglibc = vglibc[0] * 100 + vglibc[1]
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
system_suffix = f"linux-gnu-{linux_suffix}"
system_suffix = 'ubuntu-x64' if vglibc > 217 else 'centos-x64'
else:
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}'
version = "llvm-17.0.0-c5dede880d17"
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
# FIXME: remove the following once github.com/ptillet/triton-llvm-releases has arm64 llvm releases
if arch == 'arm64' and 'linux' in system_suffix:
url = f"https://github.com/acollins3/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
# use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
# release_suffix = "assert" if use_assert_enabled_llvm else "release"
rev = "b1115f8c"
name = f"llvm-{rev}-{system_suffix}"
url = f"https://tritonlang.blob.core.windows.net/llvm-builds/{name}.tar.gz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")

View File

@@ -1127,7 +1127,7 @@ void init_triton_ir(py::module &&m) {
.def("create_minf",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MinFOp>(lhs, rhs));
return mlir::Value(self.create<mlir::arith::MinimumFOp>(lhs, rhs));
})
.def("create_maxsi",
[](TritonOpBuilder &self, mlir::Value &lhs,
@@ -1142,7 +1142,7 @@ void init_triton_ir(py::module &&m) {
.def("create_maxf",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MaxFOp>(lhs, rhs));
return mlir::Value(self.create<mlir::arith::MaximumFOp>(lhs, rhs));
})
// AddPtr (similar to GEP)
.def("create_addptr",

View File

@@ -1914,7 +1914,7 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
arith_op = {
"max": {"int32": "arith.maxsi", "float32": "arith.maxf", "float16": "arith.maxf"},
"max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"},
"sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
}[reduce_op][dtype_str]
numpy_op = {

View File

@@ -216,10 +216,10 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
// CHECK-NEXT: %arg9 -> %cst_1
// CHECK-NEXT: %0#0 -> %cst
// CHECK-NEXT: %0#1 -> %cst_0
// CHECK-NEXT: %0#2 -> %cst_2,%cst_2
// CHECK-NEXT: %0#2 -> %cst_1,%cst_2,%cst_2
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: %arg11 -> %cst_1,%cst_2,%cst_2
// CHECK-NEXT: %1 -> %cst_2,%cst_2
// CHECK-NEXT: %1 -> %cst_1,%cst_2,%cst_2
%c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: %2 -> %cst_2,%cst_2
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
@@ -257,9 +257,9 @@ tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %
%5 = arith.cmpi slt, %1, %arg1 : index
cf.cond_br %5, ^bb2, ^bb3
^bb2: // pred: ^bb1
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16>
gpu.barrier
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16>
%8 = arith.addi %1, %arg2 : index
cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>)
^bb3: // pred: ^bb1

View File

@@ -12,8 +12,8 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
%dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared>
%c0 = arith.constant 0 : i32
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 0>} : !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 0>} : !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
tt.return
}
}
@@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
// CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i16) : i16
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C15]]
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
tt.return
}
}
@@ -55,7 +55,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
// CHECK: nvgpu.cluster_id
// CHECK: nvgpu.tma_load_tiled
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
tt.return
}
}

View File

@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 :
nvgpu.cga_barrier_arrive
nvgpu.cga_barrier_wait
%ptr = llvm.mlir.null : !llvm.ptr<i32, 3>
%ptr = llvm.mlir.zero : !llvm.ptr<i32, 3>
// CHECK: llvm.inline_asm
%v = nvgpu.cluster_id

View File

@@ -2,7 +2,7 @@
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @test_mbarrier() {
%mbarrier = llvm.mlir.null : !llvm.ptr<i64, 3>
%mbarrier = llvm.mlir.zero : !llvm.ptr<i64, 3>
%pred = arith.constant 1 : i1
// CHECK: llvm.inline_asm
nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr<i64, 3>

View File

@@ -2,9 +2,9 @@
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) {
%mbarrier = llvm.mlir.null : !llvm.ptr<i64, 3>
%tmaDesc = llvm.mlir.null : !llvm.ptr<i8, 1>
%dst = llvm.mlir.null : !llvm.ptr<i8, 3>
%mbarrier = llvm.mlir.zero : !llvm.ptr<i64, 3>
%tmaDesc = llvm.mlir.zero : !llvm.ptr<i8, 1>
%dst = llvm.mlir.zero : !llvm.ptr<i8, 3>
%l2desc = arith.constant 0 : i64
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
@@ -16,13 +16,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 1>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i16
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 1>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i16
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
tt.return
}

View File

@@ -2,7 +2,7 @@
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @test_tma(%opC : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) {
%buffer = llvm.mlir.null : !llvm.ptr<i64, 3>
%buffer = llvm.mlir.zero : !llvm.ptr<i64, 3>
%height = arith.constant 16 : i32
// CHECK: llvm.ptrtoint
// CHECK: llvm.inline_asm

View File

@@ -223,7 +223,7 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK-LABEL: transpose
// CHECK-LABEL: @transpose
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
@@ -920,7 +920,7 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
%29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%30 = "tt.reduce" (%29) ({
^bb0(%arg4: f32, %arg5: f32):
%max = arith.maxf %arg4, %arg5 : f32
%max = arith.maximumf %arg4, %arg5 : f32
tt.reduce.return %max : f32
}) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
@@ -1697,11 +1697,11 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2>
%123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({
^bb0(%arg28: f32, %arg29: f32):
%153 = arith.maxf %arg28, %arg29 : f32
%153 = arith.maximumf %arg28, %arg29 : f32
tt.reduce.return %153 : f32
}) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%124 = triton_gpu.convert_layout %123 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1>
%125 = arith.maxf %arg25, %124 : tensor<128xf32, #blocked1>
%125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1>
%126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1>
%127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1>
%128 = triton_gpu.convert_layout %125 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>

View File

@@ -32,19 +32,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%14 = triton_nvidia_gpu.get_thread_id : i32
%15 = arith.cmpi eq, %14, %c0_i32 : i32
%16 = arith.andi %15, %10 : i1
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%20 = tt.advance %3, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%21 = tt.advance %6, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%22 = arith.cmpi sgt, %arg5, %c128_i32 : i32
%23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1>
%24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%25 = arith.andi %15, %22 : i1
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1) : i32 {
@@ -64,10 +64,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1>
%45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%46 = arith.andi %15, %38 : i1
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%b_48 = triton_gpu.convert_layout %48 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1>
%s_48 = triton_gpu.convert_layout %b_48 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1>
@@ -133,19 +133,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%14 = triton_nvidia_gpu.get_thread_id : i32
%15 = arith.cmpi eq, %14, %c0_i32 : i32
%16 = arith.andi %15, %10 : i1
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%20 = tt.advance %3, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%21 = tt.advance %6, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%22 = arith.cmpi sgt, %arg5, %c128_i32 : i32
%23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1>
%24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%25 = arith.andi %15, %22 : i1
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%b_29 = triton_gpu.convert_layout %29 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1>
@@ -167,10 +167,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1>
%45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%46 = arith.andi %15, %38 : i1
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%51 = arith.addi %arg20, %c1_i32 : i32
%52 = arith.cmpi uge, %51, %c3_i32 : i32

View File

@@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%a_tileptr_init = tt.make_tensor_ptr %A, [%c64, %c16], [%c16, %c1], [%c0, %c0] { order = array<i32: 1, 0> } : !tt.ptr<tensor<64x16xf16>, 1>
// CHECK: %[[BUFFER:.*]] = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared>
// CHECK: %[[MBAR:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : !tt.ptr<i64, 3>
// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr<i64, 3>, i1
// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr<i64, 3>, i1
// CHECK: %[[INSERT:.*]] = triton_nvidia_gpu.insert_slice_async_v2 %[[TENSOR_PTR]], %[[BUFFER]], %{{.*}}, %[[MBAR]]
// CHECK: %[[EXT:.*]] = triton_gpu.extract_slice %[[INSERT]][0, 0, 0] [1, 64, 16] [1, 1, 1] : tensor<1x64x16xf16, #shared> to tensor<64x16xf16, #shared>
// CHECK: triton_nvidia_gpu.mbarrier_wait %[[MBAR]], %false : <i64, 3>

View File

@@ -20,7 +20,7 @@ struct TestAliasPass
return opName;
}
static void print(StringRef name, SmallVector<std::string, 4> &vals,
static void print(StringRef name, SmallVector<std::string> &vals,
raw_ostream &os) {
if (vals.empty())
return;
@@ -57,7 +57,7 @@ struct TestAliasPass
auto getAllocOpNames = [&](Value value) {
dataflow::Lattice<AliasInfo> *latticeElement =
analysis->getLatticeElement(value);
SmallVector<std::string, 4> opNames;
SmallVector<std::string> opNames;
if (latticeElement) {
auto &info = latticeElement->getValue();
for (auto &alias : info.getAllocs()) {