Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108

Conflicts:
	bin/triton-translate.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
Jason Furmanek
2023-11-08 18:51:23 +00:00
72 changed files with 1623 additions and 838 deletions

View File

@@ -239,6 +239,13 @@ set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
# TODO: Figure out which target is sufficient to fix errors; triton is
# apparently not enough. For now, link libstdc++fs for all targets to
# support older GCC compilers such as 8.3.0.
if (NOT WIN32 AND NOT APPLE)
link_libraries(stdc++fs)
endif()
if(TRITON_BUILD_PYTHON_MODULE)
add_library(triton SHARED ${PYTHON_SRC})
set(TRITON_LIBRARIES
@@ -259,7 +266,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
@@ -278,9 +284,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton ${ROCM_LIBRARIES} ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
)
# TODO: Figure out which target is sufficient to fix errors; triton is
# apparently not enough
link_libraries(stdc++fs)
endif()
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
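
The `link_libraries(stdc++fs)` lines above exist because GCC 8.x still ships `std::filesystem` in a separate `libstdc++fs` archive. A minimal, hypothetical example (not from this commit) of code that links only when that library is added:

```cpp
// With GCC 8.3 this only links if -lstdc++fs is supplied, which is what
// link_libraries(stdc++fs) arranges; GCC 9+ no longer needs it.
#include <filesystem>
#include <iostream>

int main() {
  std::filesystem::path cacheDir{"/tmp/triton-cache"}; // hypothetical path
  std::filesystem::create_directories(cacheDir);
  std::cout << std::filesystem::absolute(cacheDir) << "\n";
}
```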

View File

@@ -90,10 +90,10 @@ arbitrary LLVM version.
1. Find the version of LLVM that Triton builds against. Check `python/setup.py`
for a line like
version = "llvm-17.0.0-c5dede880d17"
version = "llvmorg-18-init-7000-g76ce4736721a"
This means that the version of Triton you have builds against
[LLVM](https://github.com/llvm/llvm-project) c5dede880d17.
[LLVM](https://github.com/llvm/llvm-project) 76ce4736721a.
2. `git checkout` LLVM at this revision. Optionally, make additional
modifications to LLVM.

View File

@@ -72,7 +72,6 @@ llvm_update_compile_flags(triton-translate)
MLIRPass
MLIRSupport
MLIRTransforms
MLIRExecutionEngine
MLIRMathToLLVM
MLIRTransformUtils
MLIRLLVMToLLVMIRTranslation

View File

@@ -104,10 +104,15 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
"", llvm::cl::desc("AMDGCN features. e.g. '+sramecc,-xnack'"),
llvm::cl::value_desc("features"), llvm::cl::init("+sramecc,-xnack"));
static llvm::cl::opt<bool> enableFpFusion(
"enable-fp-fusion", llvm::cl::desc("Enables fusion of fadd/fmul"),
llvm::cl::init(true));
llvm::InitLLVM y(argc, argv);
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
mlir::MLIRContext context;
@@ -142,12 +147,17 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::outs() << *llvmir << '\n';
} else if (targetKind == "ptx") {
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
<<<<<<< HEAD
ptxVersion.getValue());
} else if (targetKind == "hsaco") {
auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO(
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
GCNFeatures.getValue());
llvm::outs() << hsaco;
=======
ptxVersion.getValue(),
enableFpFusion.getValue());
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
} else {
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
return failure();
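
The new `--enable-fp-fusion` option is forwarded into `translateLLVMIRToPTX`, whose signature gains an `enable_fp_fusion` parameter further down. The sketch below is an assumption about how such a flag is typically consumed, namely by setting LLVM's FP-fusion policy on the `TargetOptions` used to build the NVPTX `TargetMachine`; it is not code from this commit:

```cpp
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"

// Hypothetical helper: map the boolean flag onto LLVM's FP-fusion policy
// before constructing the target machine used for PTX emission.
llvm::TargetMachine *createNVPTXMachine(const llvm::Target &target,
                                         const std::string &triple,
                                         const std::string &proc,
                                         const std::string &features,
                                         bool enableFpFusion) {
  llvm::TargetOptions options;
  options.AllowFPOpFusion = enableFpFusion ? llvm::FPOpFusion::Fast
                                           : llvm::FPOpFusion::Standard;
  return target.createTargetMachine(triple, proc, features, options,
                                    llvm::Reloc::PIC_);
}
```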

View File

@@ -63,8 +63,6 @@ Memory Ops
load
store
atomic_cas
atomic_xchg
Indexing Ops
@@ -129,10 +127,11 @@ Atomic Ops
:toctree: generated
:nosignatures:
atomic_cas
atomic_add
atomic_cas
atomic_max
atomic_min
atomic_xchg
Comparison ops

View File

@@ -63,11 +63,12 @@ private:
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
class SharedMemoryAliasAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
: public dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>> {
public:
using dataflow::SparseDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseForwardDataFlowAnalysis;
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::getLatticeElement;
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
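
Since the analysis now derives from `SparseForwardDataFlowAnalysis`, it is still driven through MLIR's `DataFlowSolver`. A minimal sketch of registering and running it with the usual companion analyses; the surrounding setup is assumed rather than taken from this diff:

```cpp
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "triton/Analysis/Alias.h" // assumed location of the class above

using namespace mlir;

// Sketch: run SharedMemoryAliasAnalysis over a module. Sparse forward
// analyses generally require DeadCodeAnalysis (and often
// SparseConstantPropagation) to be loaded so that block liveness and
// constant lattice queries are available.
LogicalResult runAliasAnalysis(Operation *moduleOp) {
  DataFlowSolver solver;
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::SparseConstantPropagation>();
  solver.load<SharedMemoryAliasAnalysis>();
  return solver.initializeAndRun(moduleOp);
}
```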

View File

@@ -271,8 +271,8 @@ private:
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
};
class AxisInfoAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>> {
private:
AxisInfoVisitorList visitors;
@@ -284,7 +284,7 @@ private:
public:
AxisInfoAnalysis(DataFlowSolver &solver);
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>>::getLatticeElement;
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;

View File

@@ -54,6 +54,8 @@ public:
SmallVector<unsigned> getScratchConfig();
SmallVector<unsigned> getOrderWithAxisAtBeginning();
unsigned getScratchSizeInBytes();
bool isSupportedLayout();
@@ -133,9 +135,9 @@ bool supportMMA(Value value, int version);
bool isSingleValue(Value value);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,

View File

@@ -27,6 +27,7 @@ include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
@@ -44,9 +45,13 @@ def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
let assemblyFormat = "attr-dict";
}
def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", []> {
let arguments = (ins I32Attr:$pendings);
def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group",
[DeclareOpInterfaceMethods<InferTypeOpInterface>,
AllTypesMatch<["input", "output"]>]> {
let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings);
let results = (outs LLVM_AnyStruct:$output);
let assemblyFormat = "attr-dict";
let assemblyFormat = "$input attr-dict `:` type($input)";
}
def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {

View File

@@ -9,9 +9,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"

View File

@@ -6,11 +6,11 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -668,6 +668,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
// Required by CallOpInterface.
MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}
}];
let assemblyFormat = [{

View File

@@ -269,17 +269,17 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> {
def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
AllTypesMatch<["input", "output"]>]> {
let summary = "dot wait";
let arguments = (ins TT_FpIntTensor:$input, I32Attr:$pendings);
let results = (outs TT_FpIntTensor:$output);
let description = [{
This operation defines the waiting action for an async dot (e.g. MMAv3).
Subsequent operations should not execute until this operation completes its wait.
}];
let arguments = (ins I32Attr:$pendings);
let assemblyFormat = "attr-dict";
let assemblyFormat = "$input attr-dict `:` type($input)";
}
def TTNG_StoreAsyncOp : TTNG_Op<"store_async",

View File

@@ -10,7 +10,8 @@ class Module;
namespace triton {
// Translate TritonGPU IR to PTX code.
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
bool enable_fp_fusion);
} // namespace triton

View File

@@ -885,7 +885,8 @@ public:
//===----------------------------------------------------------------------===//
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
: dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
solver) {
// UnrealizedConversionCast:
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast

View File

@@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
getParentOrder(getSrcLayout())[0];
}
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = triton::gpu::getOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
// insert axis at the beginning of order
order.insert(order.begin(), axis);
return order;
}
// Thread offset is the thread-index offset between two adjacent threads on
// the reduction axis within the warp.
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
@@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
if (threadsPerWarp.size() == 1) {
threadOffset = 1;
} else {
assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
auto order = triton::gpu::getOrder(srcLayout);
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
threadOffset *= threadsPerWarp[order[i]];
}
}
return threadOffset;
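
The rewritten branch walks `order` from the fastest-varying dimension and multiplies `threadsPerWarp` over every dimension that precedes the reduction axis. A standalone sketch of that arithmetic with illustrative values (not from the commit):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Sketch of the new thread-offset rule: multiply threadsPerWarp over the
// dimensions that are faster-moving than `axis` in `order`.
unsigned threadOffsetOnAxis(const std::vector<unsigned> &threadsPerWarp,
                            const std::vector<unsigned> &order,
                            unsigned axis) {
  unsigned offset = 1;
  for (unsigned dim : order) {
    if (dim == axis)
      break;
    offset *= threadsPerWarp[dim];
  }
  return offset;
}

int main() {
  // order {1, 0}: dim 1 is fastest. Reducing along dim 0 means two
  // adjacent threads on that axis are 8 lanes apart.
  std::vector<unsigned> threadsPerWarp = {4, 8};
  std::vector<unsigned> order = {1, 0};
  printf("%u\n", threadOffsetOnAxis(threadsPerWarp, order, /*axis=*/0)); // 8
  assert(threadOffsetOnAxis(threadsPerWarp, order, /*axis=*/1) == 1);
}
```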
@@ -150,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
}
bool ReduceOpHelper::isWarpSynchronous() {
auto argsLayout = getSrcLayout();
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
auto srcLayout = getSrcLayout();
auto srcShape = getSrcShape();
return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
1;
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
@@ -502,10 +515,10 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src && dst && src.getVersionMajor() == 3 &&
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
dst.getWarpsPerCTA()[1] == 1;
}
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}
@@ -521,7 +534,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getElementType().isF16();
}
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
@@ -713,7 +726,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = backwardFilter;
getBackwardSlice(currentOp, &backwardSlice, opt);
slice.insert(backwardSlice.begin(), backwardSlice.end());
// Compute and insert the forwardSlice starting from currentOp.
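
Several hunks in this commit move `getBackwardSlice` from the old filter-argument form to `BackwardSliceOptions`, additionally omitting block arguments. A minimal sketch of the new MLIR call shape; the filter predicate here is purely illustrative:

```cpp
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Sketch: collect the backward slice of `root`, skipping block arguments
// and stopping at operations the filter rejects.
SetVector<Operation *> backwardSliceOf(Operation *root) {
  BackwardSliceOptions options;
  options.omitBlockArguments = true;
  options.filter = [](Operation *op) {
    // Illustrative filter: do not walk through ops that carry regions.
    return op->getNumRegions() == 0;
  };
  SetVector<Operation *> slice;
  getBackwardSlice(root, &slice, options);
  return slice;
}
```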

View File

@@ -14,7 +14,7 @@ add_mlir_conversion_library(NVGPUToLLVM
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms

View File

@@ -29,8 +29,6 @@ const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;";
const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;";
const std::string Cga_Barrier_Sync_op = "barrier.cluster.sync.aligned;";
const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;";
const std::string Wgmma_Wait_Group_Op =
"wgmma.wait_group.sync.aligned #pendings;";
const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;";
const std::string Fence_Mbarrier_Init_Op =
"fence.mbarrier_init.release.cluster;";
@@ -200,29 +198,6 @@ public:
return {};
}
Type getReturnType(std::vector<std::string> outputConstraints,
mlir::PatternRewriter &rewriter) const {
auto ctx = rewriter.getContext();
Type resTy;
if (outputConstraints.empty()) {
resTy = void_ty(ctx);
} else {
SmallVector<Type> retTys;
for (auto &outputConstraint : outputConstraints) {
assert(outputConstraint[0] == '=' &&
"Constraint must be for an output");
Type retTy = getTypeFromConstraint(outputConstraint[1], rewriter);
retTys.push_back(retTy);
}
if (retTys.size() == 1) {
resTy = retTys[0];
} else {
resTy = struct_ty(retTys);
}
}
return resTy;
}
std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const {
std::vector<std::pair<int, int>> patchLocations;
std::vector<std::string> patchValues;
@@ -285,7 +260,8 @@ public:
outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end());
auto &ptxInstr = *ptxBuilder.create<PTXInstr>(ptxAsmPatched);
ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true);
auto retTy = getReturnType(outputConstraints, rewriter);
auto retTy =
op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType();
auto res = ptxBuilder.launch(rewriter, loc, retTy,
/*hasSideEffects*/ hasSideEffects);
if (op->getNumResults() == 0) {
@@ -700,6 +676,45 @@ public:
}
};
class WGMMAWaitGroupOpPattern
: public NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp,
WGMMAWaitGroupOpPattern> {
public:
using Base =
NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp, WGMMAWaitGroupOpPattern>;
using Base::Base;
std::vector<std::string>
getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = op.getType().cast<LLVM::LLVMStructType>();
uint32_t numOutputRegs = outputStructType.getBody().size();
std::string output =
outputStructType.getBody().front().isF32() ? "=f" : "=r";
return std::vector<std::string>(numOutputRegs, output);
}
OperandsAndConstraints
getOperandsAndConstraints(ttn::WGMMAWaitGroupOp op) const {
OperandsAndConstraints operandsAndConstraints;
auto input = op.getInput();
operandsAndConstraints.push_back({input, "0"});
return operandsAndConstraints;
}
std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = op.getType().dyn_cast<LLVM::LLVMStructType>();
uint32_t numCRegs = outputStructType.getBody().size();
std::string args = "";
uint32_t asmOpIdx = 0;
for (uint32_t i = 0; i < numCRegs; ++i) {
args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
}
auto ptxAsm = "// wait for regs: " + args + "\n\t" +
"wgmma.wait_group.sync.aligned #pendings;";
return ptxAsm;
}
};
class WGMMAOpPattern : public NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern> {
public:
using Base = NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern>;
@@ -1072,7 +1087,6 @@ public:
POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op)
POPULATE_NVGPU_OP(ttn::CGABarrierSyncOp, Cga_Barrier_Sync_op)
POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op)
POPULATE_NVGPU_OP(ttn::WGMMAWaitGroupOp, Wgmma_Wait_Group_Op)
POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op)
POPULATE_NVGPU_OP(ttn::FenceMBarrierInitOp, Fence_Mbarrier_Init_Op)
POPULATE_NVGPU_OP(ttn::CGABarrierArriveOp, Cga_Barrier_Arrive_Op)
@@ -1100,7 +1114,8 @@ public:
OffsetOfStmatrixV4OpPattern, MBarrierArriveOpPattern,
ClusterArriveOpPattern, TMALoadTiledOpPattern,
TMAStoreTiledOpPattern, LoadDSmemOpPattern, WGMMAOpPattern,
StoreDSmemOpPattern, OffsetOfSts64OpPattern>(context);
WGMMAWaitGroupOpPattern, StoreDSmemOpPattern,
OffsetOfSts64OpPattern>(context);
if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed())
signalPassFailure();
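
`WGMMAWaitGroupOpPattern` ties each input operand to constraint `"0"` and derives matching `=f`/`=r` output constraints, so the wait leaves the accumulator registers in place rather than copying them. The same tying idiom in ordinary GCC/Clang extended inline assembly, shown only as an analogy for what those constraints mean:

```cpp
#include <cstdint>

// Analogy: "=r"(out) paired with "0"(in) tells the compiler the output
// must occupy the same register as input 0, i.e. the value stays put
// across the (here empty) asm statement.
static inline uint32_t keepInRegister(uint32_t in) {
  uint32_t out;
  asm volatile("/* e.g. a wait instruction */"
               : "=r"(out)
               : "0"(in));
  return out;
}
```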

View File

@@ -49,7 +49,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
ASMBuilder
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms

View File

@@ -146,10 +146,8 @@ struct DotWaitOpConversion
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto pendings = op.getPendings();
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(op.getLoc(), pendings);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
op, adaptor.getInput(), pendings);
return success();
}
};

View File

@@ -168,7 +168,9 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
int numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0] * wpt[0]);
int numRepK = ceil<unsigned>(shapePerCTA[1], instrShape[2]);
Value warp = udiv(thread, i32_val(32));
// The descriptor should be calculated based on the first warp of the
// warpgroup.
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
Value warpM = urem(warp, i32_val(wpt[0]));
Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
@@ -199,7 +201,7 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
int numRepK = ceil<unsigned>(shapePerCTA[0], instrShape[2]);
int numRepN = ceil<unsigned>(shapePerCTA[1], instrShape[1] * wpt[1]);
Value warp = udiv(thread, i32_val(32));
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
Value warpMN = udiv(warp, i32_val(wpt[0]));
Value warpN = urem(warpMN, i32_val(wpt[1]));
Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
@@ -293,6 +295,26 @@ static bool isZero(Value v) {
return false;
}
static SmallVector<Value> emitWait(ConversionPatternRewriter &rewriter,
Location loc, SmallVector<Value> acc,
int pendings) {
SmallVector<Type> types(acc.size(), acc[0].getType());
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
int i = 0;
for (Value v : acc) {
llvmStruct = insert_val(structTy, llvmStruct, v, i++);
}
Value res = rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, llvmStruct,
pendings);
SmallVector<Value> results;
for (int i = 0; i < acc.size(); ++i) {
results.push_back(extract_val(types[0], res, i));
}
return results;
}
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
Operation *op, Value a, Value b, Value c, Value d,
@@ -427,7 +449,7 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
if (sync)
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, 0);
mmaResults = emitWait(rewriter, loc, mmaResults, 0);
SmallVector<Value> results =
unpackAccumulator(rewriter, loc, mmaResults, dTensorTy);

View File

@@ -6,7 +6,24 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
} else {
ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
}
return ret;
}
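
The emulated branch above works directly on the f16 bit pattern: e5m2 reuses f16's 5-bit exponent, so the conversion keeps the top byte of each half and rounds away the low byte. A scalar C++ model of that bit manipulation, ignoring NaN/Inf handling and the exact tie behaviour of the packed PTX:

```cpp
#include <cstdint>
#include <cstdio>

// Scalar model of the emulated f16 -> e5m2 path: clear the lowest
// mantissa bit, add half of the discarded ULP, then keep the high byte
// (what "prmt ... 0x7531" selects per 16-bit lane).
static uint8_t fp16BitsToFp8E5M2(uint16_t h) {
  uint16_t v = static_cast<uint16_t>(h & 0xfffeu); // and.b32 ... 0xfffefffe
  v = static_cast<uint16_t>(v + 0x0080u);          // add.u32 ... 0x00800080
  return static_cast<uint8_t>(v >> 8);
}

int main() {
  printf("%#04x\n", fp16BitsToFp8E5M2(0x3c00)); // f16 1.0 -> e5m2 0x3c
}
```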
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
@@ -356,11 +373,115 @@ const std::string Bf16_to_Fp8E5M2 =
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
#endif
=======
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
} else {
ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
}
return ret;
}
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
// exponent compensate = 112
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
} else {
ret = "{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
".reg .b16 b<2>; \n"
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
"mov.b32 {a0, a1}, a; \n"
"cvt.bf16.f16 b0, a0; \n"
"cvt.bf16.f16 b1, a1; \n"
"mov.b32 $0, {b0, b1}; \n"
"}";
}
return ret;
}
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret =
"{ \n" // bf16=fp8>>3 + 112<<7
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
"mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800
"mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0
"mov.u32 rn_, 0x00100010; \n" // round to nearest
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
// nosign = clamp(nosign, min, max)
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
"sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800
"sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset)
"shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3
"shl.b32 nosign1, nosign1, 3; \n" // shift into to fp8e4
"prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200
// nosign1 = 0xf300f400
// nosign = 0xf3f4f1f2
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
} else {
ret = "{ \n"
".reg .b16 a<2>; \n"
".reg .f32 b<2>; \n"
"mov.b32 {a0, a1}, $1; \n"
"cvt.f32.bf16 b0, a0; \n"
"cvt.f32.bf16 b1, a1; \n"
"cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n"
"}";
}
return ret;
}
/* ----- FP8E4M3B15 ------ */
// This data-type is a variant of the standard FP8E4M3 format.
// It was designed for fast software conversion to FP16 on
// nvidia GPUs that do not support it natively.
<<<<<<< HEAD
// Specifically, this data-type:
// - has infinities
// - has multiple nans (when all exponent bits are 1)
@@ -404,6 +525,11 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
=======
// This is the same format as FP8E4M3Nv, but:
// - the exponent bias is 15 instead of 7
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
const std::string Fp8E4M3B15_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
@@ -416,6 +542,7 @@ const std::string Fp8E4M3B15_to_Fp16 =
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"shl.b32 $1, b1, 7; \n"
"} \n";
<<<<<<< HEAD
#endif
#ifdef USE_ROCM
@@ -464,6 +591,10 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
=======
static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
std::string ret;
ret += "{ \n"
".reg .pred p<4>; \n"
@@ -509,6 +640,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
// $0 = (($2 << 1) & 0x80008000u) | (($2 << 7) & 0x3f803f80u);
// $1 = (($2 << 0) & 0x80008000u) | (($2 << 0) & 0x3f803f80u);
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -540,6 +672,9 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp8E4M3B15x4_to_Fp16 =
=======
static const std::string Fp8E4M3B15x4_to_Fp16 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"add.u32 a0, $2, $2; \n"
@@ -557,6 +692,7 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
// ((e4.y >> 0) & (0x80008000u >> 0)) |
// ((e4.y >> 0) & (0x3f803f80u >> 0)) ;
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
@@ -591,6 +727,9 @@ Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp16_to_Fp8E4M3B15x4 =
=======
static const std::string Fp16_to_Fp8E4M3B15x4 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"shr.b32 a0, $1, 1; \n"
@@ -904,17 +1043,18 @@ const std::string Bf16_to_Fp8E4M3 =
#endif
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
"}";
static const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
"}";
// Fp16 (x2) -> Fp8E4M3 (x2) (packed)
const std::string Fp16_to_Fp8E4M3Nv = "{ \n"
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";
static const std::string Fp16_to_Fp8E4M3Nv =
"{ \n"
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";
#ifndef USE_ROCM
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
const std::string Fp8E4M3Nv_to_Bf16 =
static const std::string Fp8E4M3Nv_to_Bf16 =
"{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
@@ -927,7 +1067,7 @@ const std::string Fp8E4M3Nv_to_Bf16 =
"}";
// Bf16 (x2) -> Fp8E4M3 (x2) (packed)
const std::string Bf16_to_Fp8E4M3Nv =
static const std::string Bf16_to_Fp8E4M3Nv =
"{ \n"
".reg .b16 a<2>; \n"
".reg .f32 b<2>; \n"
@@ -938,7 +1078,7 @@ const std::string Bf16_to_Fp8E4M3Nv =
"}";
/* ----- Packed integer to BF16 ------ */
const std::string S8_to_Bf16 =
static const std::string S8_to_Bf16 =
"{ \n"
".reg .s8 s<4>; \n"
".reg .f32 f<4>; \n"
@@ -952,6 +1092,12 @@ const std::string S8_to_Bf16 =
"}";
#endif
// Fp32 (x2) -> Fp8 (x2) (packed)
static const std::string Fp32_to_Fp8E4M3Nv =
"cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n";
static const std::string Fp32_to_Fp8E5M2 =
"cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n";
static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
Type inType, Type ouType) {
auto inTensorTy = inType.dyn_cast<RankedTensorType>();
@@ -1383,9 +1529,14 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
<<<<<<< HEAD
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
=======
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
// F16 -> F8
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
@@ -1393,27 +1544,44 @@ struct FpToFpOpConversion
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
<<<<<<< HEAD
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#ifndef USE_ROCM
=======
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
#endif
// BF16 -> F8
<<<<<<< HEAD
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
#ifndef USE_ROCM
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
#endif
=======
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
// F32 -> F8
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
};
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ()) {
if (srcTy.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 && srcTy.isFloat8E5M2())) {
inVecWidthBits = 16;
outVecWidthBits = 32;
}
if (dstTy.isFloat8E4M3FNUZ()) {
if (dstTy.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 && dstTy.isFloat8E5M2())) {
inVecWidthBits = 32;
outVecWidthBits = 16;
}
@@ -1450,18 +1618,24 @@ struct FpToFpOpConversion
size_t numElements = 4;
if (srcElementType.isFloat8E4M3FNUZ() ||
dstElementType.isFloat8E4M3FNUZ()) {
dstElementType.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 &&
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
numElements = 2;
}
bool isSrcFP32 = srcElementType.isF32();
bool useFP16IntermediateSrc =
srcElementType.isF32() &&
!(computeCapability >= 90 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
bool isDstFP32 = dstElementType.isF32();
auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
auto cvtFunc =
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
SmallVector<Value> inVals;
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
}
if (isSrcFP32)
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16(loc, rewriter, v);
inVals.resize(numElements,
@@ -2115,18 +2289,18 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \

View File

@@ -1549,10 +1549,12 @@ struct InsertSliceAsyncOpConversion
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
"Unexpected srcLayout in InsertSliceAsyncOpConversion"));
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
"insert_slice_async: Unexpected rank of %src");
Value llDst = adaptor.getDst();
@@ -1617,25 +1619,15 @@ struct InsertSliceAsyncOpConversion
unsigned numElems = getTotalElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
DenseMap<unsigned, Value> sharedPtrs =
getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
smemObj, rewriter, offsetVals, srcStrides);
// If perPhase * maxPhase > threadsPerCTA, we will have elements
// that share the same tile indices. The index calculation will
// be cached.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide a
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// 16 * 8 = 128bits
auto maxBitWidth =

View File

@@ -419,16 +419,15 @@ private:
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];
if (!helper.isReductionOnLayoutFastAxis()) {
std::reverse(order.begin(), order.end());
}
auto smemOrder = helper.getOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> acc = it.second;
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = warpIdAxis;
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShape, smemOrder);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemPtrTy = getElementPtrType(op, i);
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
@@ -513,10 +512,7 @@ private:
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto order = getOrder(srcLayout);
if (!helper.isReductionOnLayoutFastAxis()) {
std::reverse(order.begin(), order.end());
}
auto smemOrder = helper.getOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
@@ -532,7 +528,7 @@ private:
SmallVector<Value> readIdx = resultIndices[j];
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShape, order);
linearize(rewriter, loc, readIdx, smemShape, smemOrder);
Value readPtr =
gep(getElementPtrType(op, i), smemBases[i], readOffset);
resultVals[j] = load(readPtr);

View File

@@ -622,10 +622,13 @@ struct AllocTensorOpConversion
// TODO: we need to modify the pipeline pass to give a proper shared
// encoding to 3D tensors
SmallVector<unsigned> newOrder;
if (resultTy.getShape().size() == 3)
newOrder = {1 + order[0], 1 + order[1], 0};
else
if (resultTy.getShape().size() != order.size()) {
for (auto i = 0; i < order.size(); ++i)
newOrder.push_back(order[i] + 1);
newOrder.push_back(0);
} else {
newOrder = SmallVector<unsigned>(order.begin(), order.end());
}
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
auto smemObj =
@@ -659,10 +662,13 @@ struct ExtractSliceOpConversion
SmallVector<Value, 4> opOffsetVals;
SmallVector<Value, 4> offsetVals;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i))
opOffsetVals.emplace_back(adaptor.getOffsets()[i]);
else
for (auto i = 0, j = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i)) {
// adaptor.getOffsets() returns the list of dynamic offsets; its size may
// not be the same as that of mixedOffsets
opOffsetVals.emplace_back(adaptor.getOffsets()[j]);
++j;
} else
opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
}
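
The `ExtractSliceOpConversion` fix keeps a second counter because `adaptor.getOffsets()` holds only the dynamic offsets while `getMixedOffsets()` interleaves static and dynamic ones. A small illustration of the same two-index walk over made-up data:

```cpp
#include <cstdio>
#include <optional>
#include <vector>

int main() {
  // Mixed offsets: std::nullopt marks a dynamic offset whose value lives
  // in a separate, shorter list (like adaptor.getOffsets()).
  std::vector<std::optional<int>> mixedOffsets = {0, std::nullopt, 4,
                                                  std::nullopt};
  std::vector<int> dynamicOffsets = {7, 9};

  std::vector<int> resolved;
  for (size_t i = 0, j = 0; i < mixedOffsets.size(); ++i) {
    if (!mixedOffsets[i].has_value())
      resolved.push_back(dynamicOffsets[j++]); // consume the dynamic list
    else
      resolved.push_back(*mixedOffsets[i]);
  }
  for (int v : resolved)
    printf("%d ", v); // 0 7 4 9
  printf("\n");
}
```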

View File

@@ -146,7 +146,8 @@ protected:
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, LLVM::CConv::C, attributes);
/*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/SymbolRefAttr{},
attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
@@ -361,8 +362,13 @@ public:
unsigned numElemsPerSwizzlingRow =
swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth();
Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
unsigned leadingDimOffset;
if (outOrder.size() == 2) {
leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
} else {
leadingDimOffset = numElemsPerSwizzlingRow;
}
Value leadingDimOffsetVal = i32_val(leadingDimOffset);
// Return values
DenseMap<unsigned, Value> ret;
@@ -374,9 +380,15 @@ public:
// Extract multi dimensional index for current element
auto idx = srcIndices[elemIdx];
Value idxCol = idx[outOrder[0]]; // contiguous dimension
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
Value idxRow, strideRow;
if (outOrder.size() == 2) {
idxRow = idx[outOrder[1]]; // discontiguous dimension
strideRow = srcStrides[outOrder[1]];
} else {
idxRow = i32_val(0);
strideRow = i32_val(0);
}
Value strideCol = srcStrides[outOrder[0]];
Value strideRow = srcStrides[outOrder[1]];
// compute phase = (row // perPhase) % maxPhase
Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
// extract dynamic/static offset for immediate offsetting
@@ -428,10 +440,16 @@ public:
offset = add(offset, add(rowOff, mul(colOff, strideCol)));
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
// compute immediate offset
Value immedateOff =
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff);
Value immediateOff;
if (outOrder.size() == 2) {
immediateOff =
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
} else {
immediateOff = i32_val(immedateOffCol);
}
ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff);
}
return ret;
}

View File

@@ -371,13 +371,15 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
Type type = val.getType();
if (type != i32_ty) {
val = bitcast(val, int_ty(bits));
val = zext(i32_ty, val);
if (bits < 32)
val = zext(i32_ty, val);
}
Value mask = i32_val(0xFFFFFFFF);
Value result = rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, val, i, clamp,
mode, UnitAttr());
if (type != i32_ty) {
result = trunc(int_ty(bits), result);
if (bits < 32)
result = trunc(int_ty(bits), result);
result = bitcast(result, type);
}
return result;

View File

@@ -97,7 +97,7 @@
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
} while (0)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
// Types

View File

@@ -115,8 +115,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
// Floating point
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
// MaxMin
GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
// Floating point
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,

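The `MaxFOp`/`MinFOp` to `MaximumFOp`/`MinimumFOp` renames here follow upstream MLIR's split of float min/max into NaN-propagating (`arith.maximumf`) and NaN-ignoring (`arith.maxnumf`) forms. A small C++ sketch of the semantic difference; note that the elementwise lowering earlier in this commit still maps `MaximumFOp` onto `LLVM::MaxNumOp`, which has the NaN-ignoring behaviour:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

// NaN-propagating "maximum": any NaN input poisons the result.
static double maximumf(double a, double b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<double>::quiet_NaN();
  return a > b ? a : b;
}

int main() {
  double nan = std::nan("");
  // std::fmax has "maxnum" semantics: it returns the non-NaN operand.
  printf("fmax: %f  maximumf: %f\n", std::fmax(1.0, nan), maximumf(1.0, nan));
}
```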
View File

@@ -1,9 +1,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

View File

@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
TritonGPUAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRGPUOps
MLIRGPUDialect
TritonIR
)

View File

@@ -1603,7 +1603,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
// Deduce operandSegmentSizes from the number of the operands.
auto operandSegmentSizesAttrName =
OpT::getOperandSegmentSizesAttrName(result.name);
result.addAttribute(
@@ -1616,7 +1616,7 @@ template <class OpT>
void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
printer << " ";
printer << insertSliceOp.getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
// "operandSegmentSizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(
insertSliceOp->getAttrs(),
{insertSliceOp.getOperandSegmentSizesAttrName()});

View File

@@ -139,7 +139,10 @@ class BlockedToMMA : public mlir::RewritePattern {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
mlir::getBackwardSlice(x, &slice, bwdFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = bwdFilter;
getBackwardSlice(x, &slice, opt);
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
@@ -235,8 +238,11 @@ public:
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
getBackwardSlice(a, &aBwdSlices, {isCvt});
getBackwardSlice(b, &bBwdSlices, {isCvt});
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = isCvt;
getBackwardSlice(a, &aBwdSlices, opt);
getBackwardSlice(b, &bBwdSlices, opt);
// get the source of the first conversion found in slices
auto getCvtArgOrder = [](Operation *op) {
return cast<ConvertLayoutOp>(op)

View File

@@ -98,7 +98,9 @@ public:
// and all operations between the load and the conversion
// should be layout preserving
SetVector<Operation *> slice;
getBackwardSlice(op, &slice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op, &slice, opt);
int loadIdx = -1;
bool checkOp = false;
for (int i = 0; i < slice.size(); i++) {

View File

@@ -160,6 +160,8 @@ class LoopPipeliner {
void checkOpShareBarriers(SetVector<Operation *> &ops);
int numLoadsRequireAsyncWait = 0;
int numLoadsRequireMBarrier = 0;
// Number of buffers to allocate for each input.
int numSharedMemorySlices = 0;
/// Iterator values
Value nextIV;
@@ -280,9 +282,12 @@ class LoopPipeliner {
public:
LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs,
bool mode, ConsumerReleaseMap &consumerReleaseMap)
bool mode, int numSharedMemorySlices,
ConsumerReleaseMap &consumerReleaseMap)
: forOp(forOp), numStages(numStages), numWarps(numWarps),
numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) {
numCTAs(numCTAs), mode(mode),
numSharedMemorySlices(numSharedMemorySlices),
consumerReleaseMap(consumerReleaseMap) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
@@ -644,7 +649,7 @@ void LoopPipeliner::createBufferTypes() {
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
bufferShape.insert(bufferShape.begin(), numSharedMemorySlices);
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
Attribute sharedEnc;
if (auto dotOpEnc = cvt.getType()
@@ -946,6 +951,11 @@ void LoopPipeliner::emitPrologue() {
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
Value numSlices = builder.create<arith::ConstantIntOp>(
iv.getLoc(), numSharedMemorySlices, 32);
Value _0 = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx,
numSlices, pipelineIterIdx, _0);
// Some values have not been used by any ops in the loop body
for (BlockArgument arg : forOp.getRegionIterArgs())
setValueMappingYield(arg, valueMapping[arg][stage], stage + 1);
@@ -1220,11 +1230,13 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
Value _1 = builder.create<arith::ConstantIntOp>(idxLoc, 1, 32);
Value numStagesVal =
builder.create<arith::ConstantIntOp>(idxLoc, numStages, 32);
Value numSlices =
builder.create<arith::ConstantIntOp>(idxLoc, numSharedMemorySlices, 32);
// nextWaitIdx
Value waitIdxPlusOne = builder.create<arith::AddIOp>(idxLoc, curWaitIdx, _1);
Value nextWaitIdx = getBoundedIterationValue(
builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0);
Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne,
numSlices, waitIdxPlusOne, _0);
// Indices of InsertSliceAsyncOp and ExtractSliceOp
Value insertSliceIndex = pipelineIterIdx;
@@ -1417,9 +1429,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
// Bump pipelineIterIdx
Value pipelineIterIdxPlusOne =
builder.create<arith::AddIOp>(idxLoc, pipelineIterIdx, _1);
pipelineIterIdx =
getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal,
pipelineIterIdxPlusOne, _0);
pipelineIterIdx = getBoundedIterationValue(
builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0);
// Bump curWaitIdx
curWaitIdx = nextWaitIdx;
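
Both `pipelineIterIdx` and the wait index now wrap modulo `numSharedMemorySlices` instead of `numStages`. `getBoundedIterationValue` is not shown in this diff; it presumably selects between its last two operands depending on whether the value stays below the bound. A scalar sketch under that assumption:

```cpp
#include <cassert>

// Assumed semantics of the bounded-iteration helper: keep `thenVal` while
// `value` is below `bound`, otherwise fall back to `elseVal`.
static int boundedIterationValue(int value, int bound, int thenVal,
                                 int elseVal) {
  return value < bound ? thenVal : elseVal;
}

int main() {
  const int numSlices = 3; // e.g. numStages - 1 buffers before Hopper
  int idx = 0;
  for (int iter = 0; iter < 7; ++iter) {
    int next = idx + 1;
    idx = boundedIterationValue(next, numSlices, next, 0); // 1,2,0,1,2,0,1
  }
  assert(idx == 1);
}
```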
@@ -1516,10 +1527,23 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
// applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
llvm::SmallVector<scf::ForOp> newForOps;
// Currently we schedule stage 0 after stage `numStages - 1` during
// pipelining, therefore we only need `numStages - 1` slices of memory.
// On Hopper we have a separate post-processing that pipelines wgmma so we
// need an extra buffer for each input.
// Note that an alternative would be to keep allocating `numStages` buffers
// and remove the barrier between the loads from shared memory and the
// copies from global to shared. This would require improving existing
// membar analysis.
int numSharedMemorySlices =
computeCapability < 90 ? numStages - 1 : numStages;
// Do the pipelining
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
this->numCTAs, mode, consumerReleaseMap);
this->numCTAs, mode, numSharedMemorySlices,
consumerReleaseMap);
if (pipeliner.initialize().failed())
return;
@@ -1593,7 +1617,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
/// dots to be pipelined
SetVector<Value> dots;
SmallVector<tt::DotOp> dots;
SmallVector<unsigned> resultNeedSync;
for (Operation &op : *loop) {
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
@@ -1615,8 +1640,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
if (!CArg || !CArg.hasOneUse())
valid = false;
if (valid)
dots.insert(dotOp);
if (valid) {
dots.push_back(dotOp);
resultNeedSync.push_back(
dotOp->getUses().begin()->getOperandNumber());
}
}
}
}
@@ -1627,39 +1655,39 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
return;
OpBuilder builder(forOp);
// 0. insert dot_wait after the last dot in the loop
Value dot = dots.back();
auto loc = dot.getLoc();
builder.setInsertionPointAfter(dot.getDefiningOp());
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(loc, dots.size());
// 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
// wgmma ops by one stage.
// This is needed to prevent shared memory inputs from being overwritten
// before the operation has completed.
// TODO: merge this with the rest of the pipelining transformation and look at
// a better representation for async dots.
tt::DotOp lastDot = dots.back();
builder.setInsertionPointAfter(lastDot);
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
lastDot.getLoc(), lastDot.getResult(), dots.size());
// 1. replace Dot with DotAsync
for (size_t idx = 0; idx < dots.size(); ++idx) {
Value dot = dots[idx];
auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
builder.setInsertionPoint(dot.getDefiningOp());
tt::DotOp dotOp = dots[idx];
builder.setInsertionPoint(dotOp);
auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
dot.getDefiningOp()->erase();
dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
dotOp.replaceAllUsesWith(dotAsync.getResult());
updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
dotOp->erase();
}
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
Value loopNotEmpty = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
forOp.getUpperBound());
// TODO[goostavz]: placing the DotWaitOp in an IfOp is a workaround for a
// ptxas bug that mis-analyzes the control flow and, for safety, turns the
// GMMA into a synchronous implementation.
// Remove this If once the bug is fixed.
auto ifOp = builder.create<scf::IfOp>(loc, ArrayRef<Type>{}, loopNotEmpty,
/*hasElse*/ false);
builder.setInsertionPointToStart(ifOp.thenBlock());
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), 0);
for (unsigned resultIndex : resultNeedSync) {
Value result = forOp->getResult(resultIndex);
if (result.use_empty())
continue;
auto dotWait =
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
}
}
Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,

View File

@@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr;
//
// -----------------------------------------------------------------------------
<<<<<<< HEAD
// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
@@ -102,6 +103,9 @@ public:
};
//
=======
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
class ConvertDotConvert : public mlir::RewritePattern {
public:
ConvertDotConvert(mlir::MLIRContext *context)
@@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
getForwardSlice(currentValue, &forwardSlice);
for (Operation *op : forwardSlice) {
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
if (convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>())
return true;
Attribute dstEncoding = convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding();
if (auto mmaLayout =
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
return (mmaLayout.getVersionMajor() > 1) ? true
: mmaLayout == encoding;
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return encoding.cast<triton::gpu::MmaEncodingAttr>()
.getVersionMajor() > 1;
}
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)
@@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
return rewrittenValue;
OpBuilder rewriter(value.getContext());
rewriter.setInsertionPointAfterValue(rewrittenValue);
// Workaround: The pipeliner will insert async.wait after a pipelined loop
// to ensure that there are no pending copies and it is safe to reuse shared
// memory. We shouldn't insert ops that may use shared memory in between the
// loop and the async.wait. This is a hack until we fix the IR
// representation of async wait.
if (Operation *op = rewrittenValue.getDefiningOp()) {
if (isa<triton::gpu::AsyncWaitOp>(op->getNextNode()))
rewriter.setInsertionPointAfter(op->getNextNode());
}
auto tmpType = RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
Value converted = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -1122,7 +1140,6 @@ public:
hoistConvert(m);
mlir::RewritePatternSet decomposePatterns(context);
decomposePatterns.add<DecomposeDotOperand>(context);
decomposePatterns.add<ConvertDotConvert>(context);
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
.failed()) {

View File

@@ -91,7 +91,7 @@ private:
// support ForOp only
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
// prologue
auto iterOperands = forOp.getIterOperands();
auto iterOperands = forOp.getInitArgs();
if (argNum == 0)
return false;
if (dependOnSharedEncOperand(iterOperands[argNum - 1]))

View File

@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
arith::UIToFPOp, arith::XOrIOp>(op))
return true;
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,

View File

@@ -220,7 +220,9 @@ public:
SetVector<Operation *> backwardSlice;
mod.walk([&](triton::MakeTensorPtrOp op) -> void {
assert(isa<triton::FuncOp>(op->getParentOp()));
getBackwardSlice(op.getOperation(), &backwardSlice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op.getOperation(), &backwardSlice, opt);
op->removeAttr("async_agent");
});
for (auto op : backwardSlice) {

View File

@@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) {
builder.setInsertionPoint(agentIdOp);
Value globalRoleId = builder.create<arith::ConstantIntOp>(loc, 0, 32);
int globalNumWarps = 0;
SmallVector<Operation *> deprecatedOps;
for (auto cmpOp : agentIdOp->getUsers()) {
assert(isa<arith::CmpIOp>(cmpOp));
for (auto u : cmpOp->getUsers()) {
@@ -111,11 +112,14 @@ void materializeGetAgentIdOp(Operation *parentOp) {
Value cond =
builder.create<arith::AndIOp>(loc, lowerBound, upperBound);
cmpOp->getResult(0).replaceAllUsesWith(cond);
cmpOp->erase();
deprecatedOps.push_back(cmpOp);
break;
}
}
}
for (Operation *cmpOp : deprecatedOps) {
cmpOp->erase();
}
});
}
@@ -145,39 +149,24 @@ LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) {
}
Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op,
bool skipFirstWait) {
bool emptyBarrier) {
// TODO: currently we only support one loop, no nested loop, while or
// condition.
auto loc = op->getLoc();
auto forOp = op->getParentOfType<scf::ForOp>();
if (!forOp) {
return builder.create<arith::ConstantIntOp>(loc, skipFirstWait, 1);
return builder.create<arith::ConstantIntOp>(loc, emptyBarrier, 1);
}
auto defOp = op->getOperand(0).getDefiningOp();
assert(isa<ttng::CreateTokenOp>(defOp) &&
"mbarrier's definingOp is not createTokenOp");
ttng::CreateTokenOp createTokenOp = dyn_cast<ttng::CreateTokenOp>(defOp);
Value numStage =
builder.create<arith::ConstantIntOp>(loc, createTokenOp.getNum(), 32);
Value curStep = forOp.getBody()->getArguments().back();
if (curStep.getType() == builder.getIndexType()) {
curStep =
builder.create<arith::IndexCastOp>(loc, numStage.getType(), curStep);
// for (..., phase, pipelineIdx)
unsigned numArgs = forOp.getBody()->getNumArguments();
assert(numArgs > 2 && "Unexpected number of arguments");
Value curPhase = forOp.getBody()->getArgument(numArgs - 2);
if (emptyBarrier) {
Value _1_1b = builder.create<arith::ConstantIntOp>(loc, 1, 1);
curPhase = builder.create<mlir::arith::XOrIOp>(loc, curPhase, _1_1b);
}
Value curPhase = builder.create<arith::DivUIOp>(loc, curStep, numStage);
if (skipFirstWait) {
// If skipFirstWait, it waits for phaseBit 1
Value _1 = builder.create<arith::ConstantIntOp>(loc, 1, 32);
curPhase = builder.create<arith::AddIOp>(loc, curPhase, _1);
}
Value _2 = builder.create<arith::ConstantIntOp>(loc, 2, 32);
// TODO: May use alternative methods of phaseBit calculation to avoid high
// overhead of RemOp
Value phaseBit = builder.create<arith::RemUIOp>(loc, curPhase, _2);
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, phaseBit,
_0);
return curPhase;
}
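For reference, a small Python sketch (not part of the patch) checking that the loop-carried phase used above reproduces the old division-based phase bit, assuming the pipeline index starts at 0 and advances by one per iteration:

def old_phase_bit(step, num_stages, empty_barrier):
    cur_phase = step // num_stages
    if empty_barrier:               # previously `skipFirstWait`
        cur_phase += 1
    return (cur_phase % 2) != 0

def new_phase_bits(num_steps, num_stages, empty_barrier):
    idx, phase = 0, False
    for _ in range(num_steps):
        yield phase != empty_barrier    # the XOrIOp applied when emptyBarrier is set
        idx += 1
        if idx >= num_stages:           # index wraps, phase flips
            idx, phase = idx - num_stages, not phase

num_steps, num_stages = 12, 3
for eb in (False, True):
    assert list(new_phase_bits(num_steps, num_stages, eb)) == \
        [old_phase_bit(s, num_stages, eb) for s in range(num_steps)]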
int getTxBytes(ttng::InsertSliceAsyncV2Op load) {
@@ -260,7 +249,7 @@ void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op,
auto loc = op.getLoc();
// The first producer_acquire should be met immediately, so initially the
// producer skips the first wait
Value phase = getMBarrierPhaseBit(builder, op, 1);
Value phase = getMBarrierPhaseBit(builder, op, true);
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferEmpty, phase);
assert(op.getOperation()->hasAttr("async_agent"));
setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -296,7 +285,7 @@ void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op,
void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
Value bufferFull) {
auto loc = op.getLoc();
Value phase = getMBarrierPhaseBit(builder, op, 0);
Value phase = getMBarrierPhaseBit(builder, op, false);
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferFull, phase);
assert(op.getOperation()->hasAttr("async_agent"));
setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -530,6 +519,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
builder.create<arith::ConstantIntOp>(loc, nameBarrierId - 1, 32);
// Process mutex users
int numUsers = 0;
SmallVector<Operation *> deprecatedOps;
for (Operation *user : createMutexOp.getResult().getUsers()) {
numUsers++;
assert(numUsers <= 2);
@@ -543,14 +533,20 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
Value barLeave = builder.create<arith::SelectOp>(
loc, isRole0, namedBarrierId1, namedBarrierId0);
builder.create<ttng::NamedBarrierArriveOp>(loc, barLeave, numThreads);
} else
} else {
llvm_unreachable("Unexpected user of mutex");
}
deprecatedOps.push_back(user);
}
for (Operation *user : deprecatedOps) {
user->erase();
}
nameBarrierId -= 2;
nameBarrierIdEnd -= 2;
createMutexOp.erase();
});
parentOp->walk(
[](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); });
}
void processLockOp(OpBuilder &builder, ttng::LockOp op) {
@@ -587,6 +583,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
OpBuilder builder(createMutexOp);
// Process mutex users
SmallVector<Operation *> deprecatedOps;
for (Operation *user : createMutexOp.getResult().getUsers()) {
auto loc = user->getLoc();
builder.setInsertionPoint(user);
@@ -596,6 +593,9 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
processUnlockOp(builder, op);
else
llvm_unreachable("Unexpected user of mutex");
deprecatedOps.push_back(user);
}
for (Operation *user : deprecatedOps) {
user->erase();
}

View File

@@ -156,14 +156,20 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
persistentForOp.getInitArgsMutable()
.slice(persistentForOp.getInitArgs().size() - 1, 1)
.assign(newIdx);
auto yield =
llvm::cast<scf::YieldOp>(persistentForOp.getBody()->getTerminator());
auto idxPlusOneOp =
yield->getOperand(yield->getNumOperands() - 1).getDefiningOp();
assert(isa<arith::AddIOp>(idxPlusOneOp));
assert(idxPlusOneOp->getOperand(0) ==
persistentForOp.getBody()->getArgument(
persistentForOp.getBody()->getNumArguments() - 1));
pipelineIdx = persistentForOp.getBody()->getArgument(
persistentForOp.getBody()->getNumArguments() - 1);
Operation *idxPlusOneOp = nullptr;
for (OpOperand &v : pipelineIdx.getUses()) {
if (isa<arith::AddIOp>(v.getOwner())) {
idxPlusOneOp = v.getOwner();
break;
}
}
assert(idxPlusOneOp && "idxPlusOneOp should be arith::AddIOp");
Operation *use = *idxPlusOneOp->getUsers().begin();
assert(isa<scf::YieldOp>(use) || isa<arith::SelectOp>(use) ||
isa<arith::CmpIOp>(use));
idxPlusOneOp->setOperand(1, numRolesValue);
// Add operations at the start of persistentForOp
@@ -213,45 +219,6 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
unlockLocs[i] = op;
}
// Update unlockLocs
// ====================== IR after async launch dots ======================
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
// %3) {
// * triton_nvidia_gpu.producer_wait arg2
// * %5 = triton_nvidia_gpu.dot_async %4, %5
// * triton_nvidia_gpu.dot_wait {pendings = 1}
// * %6 = arith.cmpi sgt, arg0, %c0
// * scf.if %6 {
// * %7 = arith.subi arg2, c1
// * triton_nvidia_gpu.consumer_release %7
// * }
// * %8 = arith.addi arg2, c1
// * scf.yield %5, %8
// * }
// * triton_nvidia_gpu.dot_wait {pendings = 0}
// * %9 = arith.subi %0#1, c1
// * triton_nvidia_gpu.consumer_release %9
// * =======================================================================
// after async launch dots, there will be outstanding consumerReleaseOp after
// ForOp. we should expend the unlockLocs from ForOp to the outstanding
// consumerReleaseOp.
for (int i = 0; i < numRoles; ++i) {
Operation *unlockOp = unlockLocs[i];
auto filter = [&](Operation *op) {
return op->getBlock() == unlockOp->getBlock();
};
if (isa<scf::ForOp>(unlockOp)) {
SetVector<Operation *> slices;
mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter});
auto iter = llvm::find_if(slices, [](Operation *op) {
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
});
if (iter != slices.end()) {
unlockLocs[i] = *iter;
}
}
}
// Only cases where all lock/unlock locations are in same level make sense.
for (int i = 1; i < numRoles; ++i) {
if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() ||
@@ -281,6 +248,54 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
else
lockLocs[i] = unlockLocs[prevTypeIds[i]];
}
// Update lockLocs
// ====================== IR after async launch dots ======================
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
// %3) {
// * triton_nvidia_gpu.producer_wait arg2
// * %5 = triton_nvidia_gpu.dot_async %4, %5
// * triton_nvidia_gpu.dot_wait {pendings = 1}
// * %6 = arith.cmpi sgt, arg0, %c0
// * scf.if %6 {
// * %7 = arith.subi arg2, c1
// * triton_nvidia_gpu.consumer_release %7
// * }
// * %8 = arith.addi arg2, c1
// * scf.yield %5, %8
// * }
// * triton_nvidia_gpu.dot_wait {pendings = 0}
// * ...
// * triton_nvidia_gpu.consumer_release ..
// * =======================================================================
// after async launch dots, there will be outstanding consumerReleaseOp after
// ForOp. we should set the epilogue lockLocs after the outstanding
// consumerReleaseOp.
for (int i = 0; i < numRoles; ++i) {
Operation *lockOp = lockLocs[i];
if (isa<scf::ForOp>(lockOp)) {
Operation *loc = nullptr;
unsigned numOutstandingConsumerRelease = 0;
for (auto v : lockOp->getResults()) {
SetVector<Operation *> slices;
mlir::getForwardSlice(v, &slices);
auto iter = llvm::find_if(slices, [](Operation *op) {
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
});
if (iter != slices.end()) {
numOutstandingConsumerRelease++;
loc = *iter;
}
}
assert(numOutstandingConsumerRelease <= 1 &&
"should have only one outstanding "
"consumerReleaseOp after "
"async launch dots");
if (loc)
lockLocs[i] = loc;
}
}
// lock
for (int i = 0; i < numRoles; ++i) {
builder.setInsertionPointAfter(lockLocs[i]);

View File

@@ -129,11 +129,12 @@ DenseMap<Channel *, Value> createBuffer(const SmallVector<Channel *> &channels,
}
//===----------------------------------------------------------------------===//
// appendPipelineIdxToLoopArgs
// createNewLoops
//===----------------------------------------------------------------------===//
scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
// for(...) -> for(..., pipelineIdx)
scf::ForOp createNewPersistentLoop(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
auto loc = forOp.getLoc();
Block *body = forOp.getBody();
@@ -200,6 +201,117 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
return newForOp;
}
// for(...) -> for(..., phase, pipelineIdx)
scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
auto loc = forOp.getLoc();
Block *body = forOp.getBody();
// The agentId set of pipelineIdx is the union of agentId sets of all ops in
// the for loop
OpBuilderWithAgentIds builder(forOp.getContext());
builder.setAgentIdsFromArray(collectAgentIds(forOp));
builder.setInsertionPoint(forOp);
Value numStagesVal =
builder.createWithAgentIds<arith::ConstantIntOp>(loc, numStages, 32);
// 0. Append phase and pipelineIdx to block arguments
Value phase =
body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc);
Value pipelineIdx =
body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc);
// 1. prepare index and phase for next iteration
// nextIdx = curIdx + 1
// nextPhase = ((nextIdx < numStages && curPhase) || (nextIdx >= numStages &&
// curPhase^1))
// nextIdx = nextIdx >= numStages ? 0 : nextIdx
auto yieldOp = llvm::cast<scf::YieldOp>(body->getTerminator());
builder.setInsertionPoint(yieldOp);
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
Value _1_1b = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 1);
// generate index for next iter
Value nextPipelineIdx =
builder.createWithAgentIds<arith::AddIOp>(loc, pipelineIdx, one);
Value pipelineGECond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::uge, nextPipelineIdx, numStagesVal);
Value pipelineLTCond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, nextPipelineIdx, numStagesVal);
Value cyclePipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, nextPipelineIdx, numStagesVal);
nextPipelineIdx = builder.createWithAgentIds<mlir::arith::SelectOp>(
loc, pipelineGECond, cyclePipelineIdx, nextPipelineIdx);
// generate phase for next iter
Value flipPhase =
builder.createWithAgentIds<mlir::arith::XOrIOp>(loc, phase, _1_1b);
Value cond0 = builder.createWithAgentIds<mlir::arith::AndIOp>(
loc, pipelineGECond, flipPhase);
Value cond1 = builder.createWithAgentIds<mlir::arith::AndIOp>(
loc, pipelineLTCond, phase);
Value nextPhase =
builder.createWithAgentIds<mlir::arith::OrIOp>(loc, cond0, cond1);
// 2. Append phase and pipelineIdx to yield operands
yieldOp->insertOperands(yieldOp.getNumOperands(),
{nextPhase, nextPipelineIdx});
// 3. create newLoopArgs
SmallVector<Value> newLoopArgs;
for (auto operand : forOp.getInitArgs())
newLoopArgs.push_back(operand);
builder.setInsertionPoint(forOp);
Value initPipelineIdx, initEmptyIdx, initPhase;
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
if (parentForOp) {
// Make sure the prior pipelineIdx is inserted at the end of parentForOp
initPipelineIdx = parentForOp.getBody()->getArguments().back();
Value numSteps = builder.createWithAgentIds<arith::SubIOp>(
loc, forOp.getUpperBound(), forOp.getLowerBound());
numSteps = builder.createWithAgentIds<arith::AddIOp>(loc, numSteps,
forOp.getStep());
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value two = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 2, 32);
numSteps = builder.createWithAgentIds<arith::SubIOp>(loc, numSteps, one);
numSteps = builder.createWithAgentIds<arith::DivUIOp>(loc, numSteps,
forOp.getStep());
// initPipelineIdx = (parentForOp.pipelineIdx * numSteps) % numStages
// initPhase = ((parentForOp.pipelineIdx * numSteps) / numStages) & 1
initPipelineIdx = builder.createWithAgentIds<arith::MulIOp>(
loc, initPipelineIdx, numSteps);
Value pipelineIdx = builder.createWithAgentIds<arith::DivUIOp>(
loc, initPipelineIdx, numStagesVal);
initPipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, initPipelineIdx,
builder.createWithAgentIds<arith::MulIOp>(loc, pipelineIdx,
numStagesVal));
pipelineIdx =
builder.createWithAgentIds<arith::AndIOp>(loc, pipelineIdx, one);
initPhase = builder.createWithAgentIds<arith::TruncIOp>(
loc, builder.getI1Type(), pipelineIdx);
} else {
// phase init to false and pipelineIdx init to 0
initPipelineIdx = zero;
initPhase = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 1);
}
newLoopArgs.append({initPhase, initPipelineIdx});
// 4. Create newForOp and take the region of forOp
auto newForOp = builder.createWithAgentIds<scf::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
newLoopArgs);
newForOp.getRegion().takeBody(forOp.getRegion());
// 5. Replace forOp with newForOp
for (unsigned i = 0; i < forOp.getNumResults(); ++i)
forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));
forOp.erase();
return newForOp;
}
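The nested-loop initialization above follows the two formulas in the comment; a quick Python check of that arithmetic (illustrative only, assuming numSteps iterations of the inner loop per outer iteration):

def init_state(parent_idx, num_steps, num_stages):
    consumed = parent_idx * num_steps           # buffer slots consumed by earlier outer iterations
    return consumed % num_stages, (consumed // num_stages) & 1

def run_recurrence(total_steps, num_stages):
    idx = phase = 0
    for _ in range(total_steps):                # same wrap/flip rule generated in the loop body
        idx += 1
        if idx >= num_stages:
            idx -= num_stages
            phase ^= 1
    return idx, phase

num_steps, num_stages = 7, 4
for parent_idx in range(6):
    assert init_state(parent_idx, num_steps, num_stages) == \
        run_recurrence(parent_idx * num_steps, num_stages)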
//===----------------------------------------------------------------------===//
// appendPipelineIdxArgs
//===----------------------------------------------------------------------===//
@@ -217,7 +329,22 @@ void appendPipelineIdxArgs(SmallVector<Operation *> &backbone, int numStages) {
for (auto &op : orderedForOps) {
scf::ForOp parentForOp = op->getParentOfType<scf::ForOp>();
auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp);
scf::ForOp newForOp;
bool hasDotOp = false;
for (Operation &subOp : *op.getBody()) {
if (isa<triton::DotOp>(&subOp)) {
hasDotOp = true;
break;
}
}
if (hasDotOp) {
// for(...) -> for(..., phase, pipelineIdx)
newForOp = createNewMathLoop(op, numStages, parentForOp);
} else {
// for(...) -> for(..., pipelineIdx)
newForOp = createNewPersistentLoop(op, numStages, parentForOp);
}
auto backboneForItr =
std::find(backbone.begin(), backbone.end(), op.getOperation());
if (backboneForItr != backbone.end()) {
@@ -688,8 +815,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
}
builder.setAgentIdsFromArray(agentsPC);
Value pipelineIdx;
Value numStagesVal = builder.createWithAgentIds<arith::ConstantIntOp>(
headProducer->getLoc(), numStages, 32);
if (auto forOp = headProducer->getParentOfType<scf::ForOp>()) {
pipelineIdx = forOp.getBody()->getArguments().back();
} else {
@@ -700,10 +825,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
// insert ProducerAcquireOp
builder.setInsertionPoint(headProducer);
if (headProducer->getParentOfType<scf::ForOp>()) {
pipelineIdx = builder.createWithAgentIds<arith::RemSIOp>(
headProducer->getLoc(), pipelineIdx, numStagesVal);
}
builder.setAgentIdsFromArray(agentP);
builder.createWithAgentIds<ttng::ProducerAcquireOp>(headProducer->getLoc(),
token, pipelineIdx);
@@ -738,7 +859,8 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
loc, dotAsync.getResult(), 1);
// 1. insert ConsumerReleaseOp for DotAsyncOps
Value cond = builder.createWithAgentIds<arith::CmpIOp>(
@@ -747,31 +869,43 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
auto ifOp =
builder.createWithAgentIds<scf::IfOp>(loc, ArrayRef<Type>{}, cond,
/*hasElse*/ false);
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
builder.setInsertionPointToStart(ifOp.thenBlock());
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(
headConsumer->getLoc(), 1, 32);
auto oriIdx = forOp.getBody()->getArguments().back();
Value consumerReleaseIdx =
builder.createWithAgentIds<arith::SubIOp>(loc, oriIdx, one);
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
loc, consumerReleaseIdx, numStagesVal);
Value consumerReleaseIdx = forOp.getBody()->getArguments().back();
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
loc, numStages - 1, 32);
Value consumerReleaseIdxMinusOne =
builder.createWithAgentIds<arith::SubIOp>(loc, consumerReleaseIdx,
one);
cond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
loc, cond, lastStage, consumerReleaseIdxMinusOne);
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
consumerReleaseIdx);
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
0);
unsigned resultIndex = dotAsync->getUses().begin()->getOperandNumber();
Value result = forOp->getResult(resultIndex);
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
Value one_ = builder.createWithAgentIds<arith::ConstantIntOp>(
headConsumer->getLoc(), 1, 32);
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
loc, numStages - 1, 32);
consumerReleaseIdx = forOp.getResults().back();
consumerReleaseIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, consumerReleaseIdx, one_);
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
loc, consumerReleaseIdx, numStagesVal);
consumerReleaseIdxMinusOne = builder.createWithAgentIds<arith::SubIOp>(
loc, consumerReleaseIdx, one);
cond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
loc, cond, lastStage, consumerReleaseIdxMinusOne);
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
consumerReleaseIdx);
dotOp->erase();
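The cmpi/select pair that produces consumerReleaseIdx is a modular decrement of the buffer index; a one-line Python check (illustrative only):

num_stages = 4
for idx in range(num_stages):
    assert (num_stages - 1 if idx == 0 else idx - 1) == (idx - 1) % num_stages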

View File

@@ -14,7 +14,6 @@ add_mlir_translation_library(TritonLLVMIR
PUBLIC
MLIRArithToLLVM
MLIRBuiltinToLLVMIRTranslation
MLIRExecutionEngineUtils
MLIRIndexToLLVM
MLIRIR
MLIRLLVMDialect

View File

@@ -44,7 +44,8 @@ static bool findAndReplace(std::string &str, const std::string &begin,
return true;
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
bool enable_fp_fusion) {
// LLVM version in use may not officially support target hardware.
// Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -84,13 +85,15 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto target =
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
if (enable_fp_fusion)
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
opt.TrapUnreachable = true;
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt, llvm::CodeGenOpt::Aggressive)};
std::nullopt, llvm::CodeGenOptLevel::Aggressive)};
// set data layout
if (layout.empty())
module.setDataLayout(machine->createDataLayout());
@@ -106,7 +109,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
llvm::legacy::PassManager pass;
// emit
machine->addPassesToEmitFile(pass, pstream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
llvm::CodeGenFileType::AssemblyFile);
pass.run(module);
}
// post-process

View File

@@ -73,23 +73,20 @@ def get_llvm_package_info():
if arch == 'aarch64':
arch = 'arm64'
if system == "Darwin":
system_suffix = "apple-darwin"
arch = platform.machine()
system_suffix = f"macos-{arch}"
elif system == "Linux":
# TODO: arm64
vglibc = tuple(map(int, platform.libc_ver()[1].split('.')))
vglibc = vglibc[0] * 100 + vglibc[1]
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
system_suffix = f"linux-gnu-{linux_suffix}"
system_suffix = 'ubuntu-x64' if vglibc > 217 else 'centos-x64'
else:
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}'
version = "llvm-17.0.0-c5dede880d17"
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
# FIXME: remove the following once github.com/ptillet/triton-llvm-releases has arm64 llvm releases
if arch == 'arm64' and 'linux' in system_suffix:
url = f"https://github.com/acollins3/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
# use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
# release_suffix = "assert" if use_assert_enabled_llvm else "release"
rev = "b1115f8c"
name = f"llvm-{rev}-{system_suffix}"
url = f"https://tritonlang.blob.core.windows.net/llvm-builds/{name}.tar.gz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")

View File

@@ -1236,7 +1236,7 @@ void init_triton_ir(py::module &&m) {
.def("create_minf",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MinFOp>(lhs, rhs));
return mlir::Value(self.create<mlir::arith::MinimumFOp>(lhs, rhs));
})
.def("create_maxsi",
[](TritonOpBuilder &self, mlir::Value &lhs,
@@ -1251,7 +1251,7 @@ void init_triton_ir(py::module &&m) {
.def("create_maxf",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MaxFOp>(lhs, rhs));
return mlir::Value(self.create<mlir::arith::MaximumFOp>(lhs, rhs));
})
// AddPtr (similar to GEP)
.def("create_addptr",
@@ -2006,7 +2006,8 @@ void init_triton_translation(py::module &m) {
m.def(
"translate_llvmir_to_ptx",
[](const std::string llvmIR, int capability, int version) -> std::string {
[](const std::string llvmIR, int capability, int version,
bool enable_fp_fusion) -> std::string {
py::gil_scoped_release allow_threads;
// create LLVM module from C++
llvm::LLVMContext context;
@@ -2021,75 +2022,77 @@ void init_triton_translation(py::module &m) {
"lineno: " + std::to_string(error.getLineNo()));
}
// translate module to PTX
auto ptxCode =
triton::translateLLVMIRToPTX(*module, capability, version);
auto ptxCode = triton::translateLLVMIRToPTX(*module, capability,
version, enable_fp_fusion);
return ptxCode;
},
ret::take_ownership);
m.def(
"compile_ptx_to_cubin",
[](const std::string &ptxCode, const std::string &ptxasPath,
int capability) -> py::object {
std::string cubin;
{
py::gil_scoped_release allow_threads;
m.def("compile_ptx_to_cubin",
[](const std::string &ptxCode, const std::string &ptxasPath,
int capability, bool enable_fp_fusion) -> py::object {
std::string cubin;
{
py::gil_scoped_release allow_threads;
// compile ptx with ptxas
llvm::SmallString<64> fsrc;
llvm::SmallString<64> flog;
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
std::string fbin = std::string(fsrc) + ".o";
llvm::FileRemover logRemover(flog);
llvm::FileRemover binRemover(fbin);
const char *_fsrc = fsrc.c_str();
const char *_flog = flog.c_str();
const char *_fbin = fbin.c_str();
std::ofstream ofs(_fsrc);
ofs << ptxCode << std::endl;
ofs.close();
// compile ptx with ptxas
llvm::SmallString<64> fsrc;
llvm::SmallString<64> flog;
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
std::string fbin = std::string(fsrc) + ".o";
llvm::FileRemover logRemover(flog);
llvm::FileRemover binRemover(fbin);
const char *_fsrc = fsrc.c_str();
const char *_flog = flog.c_str();
const char *_fbin = fbin.c_str();
std::ofstream ofs(_fsrc);
ofs << ptxCode << std::endl;
ofs.close();
auto lineInfoOption =
triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
? ""
: " -lineinfo";
auto capabilitySuffix = (capability == 90) ? "a " : " ";
auto outputFileName = std::string(_fsrc) + ".o";
auto logRedirect = " 2> " + std::string(_flog);
std::string cmd = ptxasPath + lineInfoOption + " -v --gpu-name=sm_" +
std::to_string(capability) + capabilitySuffix +
_fsrc + " -o " + outputFileName + logRedirect;
auto lineInfoOption =
triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
? ""
: " -lineinfo";
auto fmadOption = enable_fp_fusion ? "" : " --fmad=false";
auto capabilitySuffix = (capability == 90) ? "a " : " ";
auto outputFileName = std::string(_fsrc) + ".o";
auto logRedirect = " 2> " + std::string(_flog);
std::string cmd = ptxasPath + lineInfoOption + fmadOption +
" -v --gpu-name=sm_" +
std::to_string(capability) + capabilitySuffix +
_fsrc + " -o " + outputFileName + logRedirect;
int err = system(cmd.c_str());
if (err != 0) {
err >>= 8;
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
if (err == 255) {
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
log);
} else if (err == 128 + SIGSEGV) {
throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
"` to confirm that this is a "
"bug in `ptxas`\n" +
log);
int err = system(cmd.c_str());
if (err != 0) {
err >>= 8;
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
if (err == 255) {
throw std::runtime_error(
"Internal Triton PTX codegen error: \n" + log);
} else if (err == 128 + SIGSEGV) {
throw std::runtime_error("Please run `ptxas " +
fsrc.str().str() +
"` to confirm that this is a "
"bug in `ptxas`\n" +
log);
} else {
throw std::runtime_error("`ptxas` failed with error code " +
std::to_string(err) + ": \n" + log);
}
return {};
} else {
throw std::runtime_error("`ptxas` failed with error code " +
std::to_string(err) + ": \n" + log);
llvm::FileRemover srcRemover(fsrc);
std::ifstream _cubin(_fbin, std::ios::binary);
cubin = std::string(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
// Do not return here, exit the gil scope and return below
}
return {};
} else {
llvm::FileRemover srcRemover(fsrc);
std::ifstream _cubin(_fbin, std::ios::binary);
cubin = std::string(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
// Do not return here, exit the gil scope and return below
}
}
py::bytes bytes(cubin);
return std::move(bytes);
});
py::bytes bytes(cubin);
return std::move(bytes);
});
m.def("add_external_libs",
[](mlir::ModuleOp &op, const std::vector<std::string> &names,

View File

@@ -860,9 +860,9 @@ def test_unary_op(dtype_x, expr, num_ctas, device):
# ----------------
@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin']])
def test_math_op(dtype_x, expr, device):
_test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin'] for x in ['x', '3.0']])
def test_math_op(dtype_x, expr, device, x):
_test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device)
# ----------------
# test abs
@@ -1662,10 +1662,18 @@ reduce_configs2 = [
for op in ['min', 'max', 'sum']
]
reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)]
reduce_configs3 = [
(op, 'float32', shape, axis)
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
for shape in reduce3d_shapes
for axis in [0, 1, 2]
]
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3)
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
def test_reduce(op, dtype_str, shape, axis, num_ctas, device):
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
if is_hip():
@@ -1673,17 +1681,31 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = GENERATE_TEST_HERE
if AXIS is None:
tl.store(Z, z)
elif AXIS == 1:
tl.store(Z + range_m, z)
range_k = tl.arange(0, BLOCK_K)
if IS_3D:
x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :])
else:
tl.store(Z + range_n, z)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = GENERATE_TEST_HERE
if IS_3D:
if AXIS is None:
tl.store(Z, z)
elif AXIS == 0:
tl.store(Z + range_n[:, None] * BLOCK_K + range_k[None, :], z)
elif AXIS == 1:
tl.store(Z + range_m[:, None] * BLOCK_K + range_k[None, :], z)
else:
tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
else:
if AXIS is None:
tl.store(Z, z)
elif AXIS == 0:
tl.store(Z + range_n, z)
else:
tl.store(Z + range_m, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
# input
@@ -1706,10 +1728,13 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
ret_numel = 1 if axis is None else shape[1 - axis]
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis)
z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs),
device=device, dst_type=z_tri_dtype_str)
BLOCK_K = 1 if len(shape) == 2 else shape[2]
IS_3D = bool(len(shape) == 3)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0],
BLOCK_N=shape[1], AXIS=axis, num_ctas=num_ctas)
BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas)
z_tri = to_numpy(z_tri)
# compare
if op == 'sum':
@@ -1890,7 +1915,7 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
arith_op = {
"max": {"int32": "arith.maxsi", "float32": "arith.maxf", "float16": "arith.maxf"},
"max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"},
"sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
}[reduce_op][dtype_str]
numpy_op = {
@@ -3049,6 +3074,20 @@ def test_constexpr_scalar_shape(device):
kernel[(1,)](x_tri, 32)
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
@triton.jit
def static_assert_func():
tl.static_assert(tl.constexpr(False), f"Assert is firing because the constexpr propagation did not work properly")
def test_constexpr_propagation():
@triton.jit
def _kernel(COND: tl.constexpr):
NEW_COND = COND
if NEW_COND:
static_assert_func()
_kernel[(1,)](False)
# -------------
# test call
# -------------
@@ -3893,3 +3932,22 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
else:
torch.testing.assert_close(ref_out, C)
# -----------------------
# test enable_fp_fusion
# -----------------------
@pytest.mark.parametrize("enable_fp_fusion", [False, True])
def test_enable_fp_fusion(enable_fp_fusion):
# Sequential multiply add can be fused by backend
@triton.jit
def mul_add(data):
ptrs = data + tl.arange(0, 128)
tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)
data = torch.randn((128,), device='cuda', dtype=torch.float32)
h = mul_add[(1,)](data, enable_fp_fusion=enable_fp_fusion)
found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None
assert found_fma == enable_fp_fusion

View File

@@ -228,6 +228,7 @@ class CodeGenerator(ast.NodeVisitor):
self.local_defs: Dict[str, tensor] = {}
self.global_uses: Dict[str, tensor] = {}
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
self.fn = None
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
builtin_namespace.update((
@@ -322,6 +323,8 @@ class CodeGenerator(ast.NodeVisitor):
def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args)
if self.fn:
raise UnsupportedLanguageConstruct(None, node, "nested function definition is not supported.")
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i - 1]
@@ -335,9 +338,9 @@ class CodeGenerator(ast.NodeVisitor):
self.visit(init_node)
# initialize function
visibility = "public" if self.is_kernel else "private"
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(fn)
entry = fn.add_entry_block()
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(self.fn)
entry = self.fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
@@ -350,8 +353,8 @@ class CodeGenerator(ast.NodeVisitor):
else:
if i in self.attributes:
for name, value in self.attributes[i]:
fn.set_arg_attr(idx, name, value)
arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx]))
self.fn.set_arg_attr(idx, name, value)
arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
idx += 1
insert_pt = self.builder.get_insertion_block()
@@ -367,14 +370,14 @@ class CodeGenerator(ast.NodeVisitor):
# update return type
if isinstance(self.last_ret_type, tuple):
self.prototype.ret_types = list(self.last_ret_type)
fn.reset_type(self.prototype.to_ir(self.builder))
self.fn.reset_type(self.prototype.to_ir(self.builder))
else:
self.prototype.ret_types = [self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
self.fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
# Remove dead code
fn.finalize()
self.fn.finalize()
def visit_arguments(self, node):
arg_names = []
@@ -412,6 +415,9 @@ class CodeGenerator(ast.NodeVisitor):
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
names = _names[0]
values = self.visit(node.value)
if not isinstance(node.value, ast.Constant) and _is_constexpr(values):
self.set_value(names, values)
return
if not _is_list_like(names):
names = [names]
if not _is_list_like(values):
@@ -686,9 +692,13 @@ class CodeGenerator(ast.NodeVisitor):
for name in loop_defs:
if name in liveins:
# We should not def new constexpr
assert _is_triton_tensor(loop_defs[name])
assert _is_triton_tensor(liveins[name])
assert loop_defs[name].type == liveins[name].type
assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constexpr {name} in the loop'
assert _is_triton_tensor(liveins[name]), f'cannot reassign constexpr {name} in the loop'
assert loop_defs[name].type == liveins[name].type, \
f'Loop-carried variable {name} has initial type {liveins[name].type} '\
f'but is re-assigned to {loop_defs[name].type} in loop! '\
f'Please make sure that the type stays consistent.'
# these are loop-carried values
names.append(name)
ret_types.append(loop_defs[name].type)

View File

@@ -37,6 +37,7 @@ CUDA_DEFAULT_WARP_SIZE = 32
class CudaTargetDescriptor:
capability: int
num_warps: int
enable_fp_fusion: bool
def _is_cuda(target):
@@ -147,6 +148,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability)
pm.add_tritongpu_wsmutex_pass(capability)
pm.add_tritongpu_wsmaterialization_pass(capability)
pm.add_licm_pass()
pm.add_cse_pass()
else:
if is_hip():
@@ -220,7 +222,7 @@ def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None)
if ptx_version is None:
_, cuda_version = path_to_ptxas()
ptx_version = ptx_get_version(cuda_version)
return translate_llvmir_to_ptx(mod, target.capability, ptx_version)
return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion)
def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
@@ -231,7 +233,7 @@ def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
:return: str
'''
ptxas, _ = path_to_ptxas()
return compile_ptx_to_cubin(ptx, ptxas, target.capability)
return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion)
# ------------------------------------------------------------------------------
@@ -407,8 +409,12 @@ def compile(fn, **kwargs):
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
num_ctas = kwargs.get("num_ctas", 1)
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
<<<<<<< HEAD
waves_per_eu = kwargs.get("waves_per_eu", 0)
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
=======
enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
# TODO[shuhaoj]: persistent can be decoupled with warp specialization
@@ -431,7 +437,7 @@ def compile(fn, **kwargs):
# build architecture descriptor
if device_type == "cuda":
_device_backend = get_backend(device_type)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
else:
_device_backend = get_backend(device_type)
assert _device_backend

View File

@@ -1057,14 +1057,19 @@ def load(pointer, mask=None, other=None, boundary_check=tuple(), padding_option=
"""
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
(1) `pointer` could be a single element pointer, then a scalar will be loaded
- `mask` and `other` must be scalar too
- `other` is implicitly typecast to `pointer.dtype.element_ty`
- `boundary_check` and `padding_option` must be empty
(2) `pointer` could be element-wise tensor of pointers, in which case:
- `mask` and `other` are implicitly broadcast to `pointer.shape`
- `other` is implicitly typecast to `pointer.dtype.element_ty`
- `boundary_check` and `padding_option` must be empty
(3) `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
- `mask` and `other` must be None
- `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access
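A minimal usage sketch of the pointer forms described above (hypothetical kernel, not part of this patch; assumes a contiguous float32 input of length N):

import torch
import triton
import triton.language as tl

@triton.jit
def _load_forms(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # (2) tensor of pointers: mask/other are broadcast, other is cast to the element type
    a = tl.load(x_ptr + offs, mask=offs < N, other=0.0)
    # (3) block pointer: out-of-bound elements are handled via boundary_check/padding_option
    bp = tl.make_block_ptr(base=x_ptr, shape=(N,), strides=(1,),
                           offsets=(0,), block_shape=(BLOCK,), order=(0,))
    b = tl.load(bp, boundary_check=(0,), padding_option="zero")
    tl.store(out_ptr + offs, a + b, mask=offs < N)

x = torch.randn(100, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)
_load_forms[(1,)](x, out, x.numel(), BLOCK=128)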
@@ -1103,14 +1108,20 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
"""
Store a tensor of data into memory locations defined by `pointer`:
(1) `pointer` could be a single element pointer, then a scalar will be stored
- `mask` must be scalar too
- `boundary_check` and `padding_option` must be empty
(2) `pointer` could be element-wise tensor of pointers, in which case:
- `mask` is implicitly broadcast to `pointer.shape`
- `boundary_check` must be empty
(3) or `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
- `mask` must be None
- `boundary_check` can be specified to control the behavior of out-of-bound access
`value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
:param pointer: The memory location where the elements of `value` are stored
@@ -1165,29 +1176,35 @@ def advance(base: tensor, offsets, _builder=None):
# -----------------------
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
def _decorator(func: T) -> T:
docstr = """
docstr = f"""
Performs an atomic {name} at the memory location specified by :code:`pointer`.
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to compare-and-swap.
:type pointer: Block of dtype=triton.PointerDType
:param pointer: The memory locations to operate on
:type pointer: Block of dtype=triton.PointerDType"""
if has_cmp:
docstr += """
:param cmp: The values expected to be found in the atomic object
:type cmp: Block of dtype=`pointer.dtype.element_ty`
:param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty`
:type cmp: Block of dtype=pointer.dtype.element_ty"""
docstr += """
:param val: The values with which to perform the atomic operation
:type val: Block of dtype=pointer.dtype.element_ty
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
"ACQUIRE", "RELEASE", or "RELAXED")
:type sem: str
"""
func.__doc__ = docstr.format(name=name)
func.__doc__ = docstr
return func
return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
@@ -1309,6 +1326,8 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
:type ieee_rounding: bool
"""
ieee_rounding = _constexpr_to_value(ieee_rounding)
x = _to_tensor(x, _builder)
y = _to_tensor(y, _builder)
return semantic.fdiv(x, y, ieee_rounding, _builder)
@@ -1330,36 +1349,42 @@ def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
@builtin
@_add_math_1arg_docstr("exponential")
def exp(x, _builder=None):
x = _to_tensor(x, _builder)
return semantic.exp(x, _builder)
@builtin
@_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None):
x = _to_tensor(x, _builder)
return semantic.log(x, _builder)
@builtin
@_add_math_1arg_docstr("cosine")
def cos(x, _builder=None):
x = _to_tensor(x, _builder)
return semantic.cos(x, _builder)
@builtin
@_add_math_1arg_docstr("sine")
def sin(x, _builder=None):
x = _to_tensor(x, _builder)
return semantic.sin(x, _builder)
@builtin
@_add_math_1arg_docstr("square root")
def sqrt(x, _builder=None):
x = _to_tensor(x, _builder)
return semantic.sqrt(x, _builder)
@builtin
@_add_math_1arg_docstr("absolute value")
def abs(x, _builder=None):
x = _to_tensor(x, _builder)
return semantic.abs(x, _builder)
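With the `_to_tensor` wrapping added above, these math builtins also accept plain Python scalars; a small sketch (hypothetical kernel, not part of this patch):

import torch
import triton
import triton.language as tl

@triton.jit
def _const_math(out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # scalar arguments are wrapped into tensors by _to_tensor before lowering
    val = tl.exp(1.0) + tl.sin(0.0)
    tl.store(out_ptr + offs, val)

out = torch.empty(16, device="cuda", dtype=torch.float32)
_const_math[(1,)](out, BLOCK=16)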

View File

@@ -453,7 +453,7 @@ def _unwrap(tensor):
builder = Builder()
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
class GridExecutor:

View File

@@ -281,13 +281,21 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
<<<<<<< HEAD
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
=======
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
<<<<<<< HEAD
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
=======
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
key = str(key)
class LegacyCompiler:
@@ -297,7 +305,11 @@ class JITFunction(KernelInterface[T]):
pass
kwargs = dict(signature=signature, device=device, constants=constants,
<<<<<<< HEAD
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
=======
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
@@ -351,7 +363,11 @@ class JITFunction(KernelInterface[T]):
def regular_args_v(args_proxy):
return [args_proxy[arg_name] for arg_name in regular_args]
<<<<<<< HEAD
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
=======
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
from ..compiler import (CompiledKernel, compile,
get_arch_default_num_stages,
get_arch_default_num_warps)
@@ -402,7 +418,11 @@ class JITFunction(KernelInterface[T]):
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
<<<<<<< HEAD
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug)
=======
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
@@ -430,8 +450,13 @@ class JITFunction(KernelInterface[T]):
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
<<<<<<< HEAD
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
=======
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
@@ -446,8 +471,13 @@ class JITFunction(KernelInterface[T]):
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
import triton
<<<<<<< HEAD
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
=======
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"""
scope = {"launcher_body": launcher_body}
exec(src, scope)

View File

@@ -32,13 +32,17 @@ def _attn_fwd_inner(
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
N_CTX: tl.constexpr,
):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
else:
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
@@ -72,6 +76,7 @@ def _attn_fwd_inner(
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
return acc, l_i, m_i
<<<<<<< HEAD
@triton.jit
def _attn_fwd_inner(
acc, l_i, m_i, q,
@@ -138,6 +143,31 @@ def _attn_fwd_inner(
],
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)
=======
# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
# @triton.autotune(
# configs=[
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=7, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=7, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=6, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=5, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=6, num_warps=4),
# ],
# key=['N_CTX'],
# )
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
@triton.jit
@@ -227,8 +257,12 @@ def _attn_fwd(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m,
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
<<<<<<< HEAD
4 - STAGE, offs_m, offs_n,
N_CTX, pre_load_v,
=======
4 - STAGE, offs_m, offs_n, N_CTX,
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
)
# stage 2: on-band
if STAGE & 2:
@@ -239,8 +273,12 @@ def _attn_fwd(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m,
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
<<<<<<< HEAD
2, offs_m, offs_n,
N_CTX, pre_load_v,
=======
2, offs_m, offs_n, N_CTX,
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
)
# epilogue
# write back m
@@ -701,6 +739,7 @@ class _attention(torch.autograd.Function):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
<<<<<<< HEAD
o = torch.empty_like(q, dtype=v.dtype)
if torch.version.hip is None:
BLOCK_M = 128
@@ -716,6 +755,20 @@ class _attention(torch.autograd.Function):
)
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
=======
o = torch.empty_like(q)
BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 4
stage = 3 if causal else 1
# Tuning for H100
if torch.cuda.get_device_capability()[0] == 9:
num_warps = 8
num_stages = 7 if Lk >= 64 else 3
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
_attn_fwd[grid](
q, k, v, sm_scale, M, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
@@ -726,6 +779,11 @@ class _attention(torch.autograd.Function):
N_CTX=q.shape[2],
BLOCK_DMODEL=Lk,
STAGE=stage,
<<<<<<< HEAD
=======
num_warps=num_warps,
num_stages=num_stages,
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
)
## restore the grid for bwd kernel
@@ -905,6 +963,7 @@ try:
except BaseException:
HAS_FLASH = False
<<<<<<< HEAD
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd', 'bwd']:
@@ -939,6 +998,36 @@ for mode in ['fwd', 'bwd']:
'mode': mode,
'causal': causal})
)
=======
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = []
for mode in ["fwd", "bwd"]:
for causal in [True, False]:
if mode == "bwd" and not causal:
continue
configs.append(
triton.testing.Benchmark(
x_names=["N_CTX"],
x_vals=[2**i for i in range(10, 15)],
line_arg="provider",
line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}",
args={
"H": N_HEADS,
"BATCH": BATCH,
"D_HEAD": D_HEAD,
"dtype": torch.float16,
"mode": mode,
"causal": causal,
},
)
)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
@triton.testing.perf_report(configs)
@@ -956,6 +1045,9 @@ def bench_flash_attention(
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if mode == "fwd" and TORCH_HAS_FP8:
q = q.to(torch.float8_e5m2)
k = k.to(torch.float8_e5m2)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if mode == "fwd":
q = q.to(torch_dtype)

View File

@@ -216,10 +216,10 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
// CHECK-NEXT: %arg9 -> %cst_1
// CHECK-NEXT: %0#0 -> %cst
// CHECK-NEXT: %0#1 -> %cst_0
// CHECK-NEXT: %0#2 -> %cst_2,%cst_2
// CHECK-NEXT: %0#2 -> %cst_1,%cst_2,%cst_2
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: %arg11 -> %cst_1,%cst_2,%cst_2
// CHECK-NEXT: %1 -> %cst_2,%cst_2
// CHECK-NEXT: %1 -> %cst_1,%cst_2,%cst_2
%c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: %2 -> %cst_2,%cst_2
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
@@ -257,9 +257,9 @@ tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %
%5 = arith.cmpi slt, %1, %arg1 : index
cf.cond_br %5, ^bb2, ^bb3
^bb2: // pred: ^bb1
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16>
gpu.barrier
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16>
%8 = arith.addi %1, %arg2 : index
cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>)
^bb3: // pred: ^bb1

View File

@@ -864,6 +864,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#slice1d0 = #triton_gpu.slice<{dim = 0, parent = #blocked1}>
#shared = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
// CHECK-LABEL: basic_insert_slice_async_1d
tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}) {
%c0_i32 = arith.constant 0 : i32
%cst_2 = arith.constant dense<64> : tensor<64xi32, #slice1d0>
%58 = tt.splat %arg0 : (!tt.ptr<i64, 1>) -> tensor<64x!tt.ptr<i64, 1>, #slice1d0>
%24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0>
%59 = tt.addptr %58, %24 : tensor<64x!tt.ptr<i64, 1>, #slice1d0>, tensor<64xi32, #slice1d0>
%66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr<i64, 1>, #slice1d0>, tensor<64xi32, #slice1d0>
%71 = triton_gpu.alloc_tensor : tensor<2x64xi64, #shared>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK-NEXT: cp.async.commit_group
%73 = triton_gpu.insert_slice_async %66, %71, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x!tt.ptr<i64, 1>, #slice1d0> -> tensor<2x64xi64, #shared>
triton_gpu.async_commit_group
tt.return
}
}
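The basic_insert_slice_async_1d test above expects eight 8-byte cp.async copies for one 64-element slice of i64. At the Triton language level such a transfer is just a 1D block load; a minimal kernel sketch (illustrative only, the backend emits cp.async when the load is staged through shared memory, e.g. by the pipeliner):

import triton
import triton.language as tl

@triton.jit
def copy_block_i64(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    # One program copies a contiguous 1D block of 64-bit elements.
    offs = tl.arange(0, BLOCK)
    x = tl.load(src_ptr + offs)
    tl.store(dst_ptr + offs, x)

Launched, for example, as copy_block_i64[(1,)](src, dst, BLOCK=64) on two torch.int64 CUDA tensors.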
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
@@ -2012,6 +2043,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
// -----
<<<<<<< HEAD
// CHECK-LABEL: copyitem
// GCN: llvm.store
// GCN: llvm.load
@@ -2023,11 +2055,21 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @copyitem() attributes {noinline = false} {
%cst = arith.constant dense<true> : tensor<4x1xi1, #blocked>
=======
// CHECK-LABEL: reduce_slice
// CHECK-NOT: st.shared
// CHECK-NOT: ld.shared
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}>
#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @reduce_slice() attributes {noinline = false} {
%cst = arith.constant dense<true> : tensor<4x1xi1, #sliced2>
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
%0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({
^bb0(%arg0: i1, %arg1: i1):
%1 = arith.ori %arg0, %arg1 : i1
tt.reduce.return %1 : i1
}) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #sliced2}>>
tt.return
}
}

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}>
@@ -12,8 +12,8 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
%dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared>
%c0 = arith.constant 0 : i32
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 0>} : !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 0>} : !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
tt.return
}
}
@@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
// CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i16) : i16
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C15]]
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
tt.return
}
}
@@ -55,7 +55,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
// CHECK: nvgpu.cluster_id
// CHECK: nvgpu.tma_load_tiled
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
tt.return
}
}
@@ -175,6 +175,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-LABEL: @dot_reg_operand_A
// Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: tensor<64x64xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%opA = triton_gpu.convert_layout %a : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
@@ -183,3 +184,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_fp8_to_f16_conversion
tt.func @test_fp8_to_f16_conversion(
%in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>,
%in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) {
// CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
%out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out2 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out4 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
tt.return
}
}
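The test_fp8_to_f16_conversion case above checks the PTX cvt sequences for fp8 conversions; the same round trip can be exercised at the framework level with PyTorch's fp8 dtypes. A small sketch, assuming a PyTorch build that provides float8_e5m2:

import torch

if hasattr(torch, "float8_e5m2"):
    x = torch.randn(128, dtype=torch.float32)
    x_fp8 = x.to(torch.float8_e5m2)   # f32 -> e5m2 is lossy
    x_back = x_fp8.to(torch.float16)  # every finite e5m2 value is exact in f16
    print((x - x_back.float()).abs().max())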

View File

@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 :
nvgpu.cga_barrier_arrive
nvgpu.cga_barrier_wait
%ptr = llvm.mlir.null : !llvm.ptr<i32, 3>
%ptr = llvm.mlir.zero : !llvm.ptr<i32, 3>
// CHECK: llvm.inline_asm
%v = nvgpu.cluster_id

View File

@@ -2,7 +2,7 @@
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @test_mbarrier() {
%mbarrier = llvm.mlir.null : !llvm.ptr<i64, 3>
%mbarrier = llvm.mlir.zero : !llvm.ptr<i64, 3>
%pred = arith.constant 1 : i1
// CHECK: llvm.inline_asm
nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr<i64, 3>

View File

@@ -2,9 +2,9 @@
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) {
%mbarrier = llvm.mlir.null : !llvm.ptr<i64, 3>
%tmaDesc = llvm.mlir.null : !llvm.ptr<i8, 1>
%dst = llvm.mlir.null : !llvm.ptr<i8, 3>
%mbarrier = llvm.mlir.zero : !llvm.ptr<i64, 3>
%tmaDesc = llvm.mlir.zero : !llvm.ptr<i8, 1>
%dst = llvm.mlir.zero : !llvm.ptr<i8, 3>
%l2desc = arith.constant 0 : i64
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
@@ -16,13 +16,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 1>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i16
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 1>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i16
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
tt.return
}

View File

@@ -2,7 +2,7 @@
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @test_tma(%opC : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) {
%buffer = llvm.mlir.null : !llvm.ptr<i64, 3>
%buffer = llvm.mlir.zero : !llvm.ptr<i64, 3>
%height = arith.constant 16 : i32
// CHECK: llvm.ptrtoint
// CHECK: llvm.inline_asm
@@ -30,3 +30,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
tt.return
}
}
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @wgmma_wait(%in: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) {
// CHECK: // wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63
// CHECK: wgmma.wait_group.sync.aligned 0;
%out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} :
!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.return
}
}

test/Triton/print.mlir (new file, 31 lines)
View File

@@ -0,0 +1,31 @@
// RUN: triton-translate %s --mlir-print-ir-after-all -o %t 2>&1 | FileCheck %s
// CHECK: IR Dump After SCFToControlFlow (convert-scf-to-cf)
// CHECK: tt.func public @add_kernel_0d1d2d3de
// CHECK: IR Dump After ConvertIndexToLLVMPass (convert-index-to-llvm)
// CHECK: tt.func public @add_kernel_0d1d2d3de
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel_0d1d2d3de(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked>
%6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked>
%7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
%10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked>
%14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked>
tt.return
}
}

View File

@@ -223,7 +223,7 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK-LABEL: transpose
// CHECK-LABEL: @transpose
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
@@ -920,7 +920,7 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
%29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%30 = "tt.reduce" (%29) ({
^bb0(%arg4: f32, %arg5: f32):
%max = arith.maxf %arg4, %arg5 : f32
%max = arith.maximumf %arg4, %arg5 : f32
tt.reduce.return %max : f32
}) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
@@ -1697,11 +1697,11 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2>
%123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({
^bb0(%arg28: f32, %arg29: f32):
%153 = arith.maxf %arg28, %arg29 : f32
%153 = arith.maximumf %arg28, %arg29 : f32
tt.reduce.return %153 : f32
}) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%124 = triton_gpu.convert_layout %123 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1>
%125 = arith.maxf %arg25, %124 : tensor<128xf32, #blocked1>
%125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1>
%126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1>
%127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1>
%128 = triton_gpu.convert_layout %125 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>

View File

@@ -32,19 +32,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%14 = triton_nvidia_gpu.get_thread_id : i32
%15 = arith.cmpi eq, %14, %c0_i32 : i32
%16 = arith.andi %15, %10 : i1
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%20 = tt.advance %3, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%21 = tt.advance %6, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%22 = arith.cmpi sgt, %arg5, %c128_i32 : i32
%23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1>
%24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%25 = arith.andi %15, %22 : i1
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1) : i32 {
@@ -52,7 +52,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
triton_nvidia_gpu.mbarrier_wait %33, %arg23 : <i64, 3>
// CHECK: triton_nvidia_gpu.fence_async_shared
%34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma>
triton_nvidia_gpu.dot_wait {pendings = 1 : i32}
%35 = tt.advance %arg11, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%36 = tt.advance %arg12, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%37 = arith.addi %arg19, %c128_i32 : i32
@@ -65,10 +64,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1>
%45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%46 = arith.andi %15, %38 : i1
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%b_48 = triton_gpu.convert_layout %48 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1>
%s_48 = triton_gpu.convert_layout %b_48 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1>
@@ -88,10 +87,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%64 = arith.ori %62, %63 : i1
scf.yield %34, %35, %36, %47, %49, %s_48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1
}
scf.if %10 {
triton_nvidia_gpu.dot_wait {pendings = 0 : i32}
}
%31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>
%31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1>
triton_nvidia_gpu.store_async %8, %32 : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<128x128xf16, #shared1>
triton_gpu.async_bulk_commit_group
@@ -136,19 +133,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%14 = triton_nvidia_gpu.get_thread_id : i32
%15 = arith.cmpi eq, %14, %c0_i32 : i32
%16 = arith.andi %15, %10 : i1
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%20 = tt.advance %3, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%21 = tt.advance %6, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%22 = arith.cmpi sgt, %arg5, %c128_i32 : i32
%23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1>
%24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%25 = arith.andi %15, %22 : i1
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%b_29 = triton_gpu.convert_layout %29 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1>
@@ -158,7 +155,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
triton_nvidia_gpu.mbarrier_wait %33, %arg23 : <i64, 3>
// CHECK: triton_nvidia_gpu.fence_async_shared
%34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma>
triton_nvidia_gpu.dot_wait {pendings = 1 : i32}
%35 = tt.advance %arg11, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%36 = tt.advance %arg12, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%37 = arith.addi %arg19, %c128_i32 : i32
@@ -171,10 +167,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1>
%45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> <i64, 3>
%46 = arith.andi %15, %38 : i1
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
%50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
%51 = arith.addi %arg20, %c1_i32 : i32
%52 = arith.cmpi uge, %51, %c3_i32 : i32
@@ -192,10 +188,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%64 = arith.ori %62, %63 : i1
scf.yield %34, %35, %36, %47, %49, %48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1
}
scf.if %10 {
triton_nvidia_gpu.dot_wait {pendings = 0 : i32}
}
%31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>
%31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1>
triton_nvidia_gpu.store_async %8, %32 : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<128x128xf16, #shared1>
triton_gpu.async_bulk_commit_group

View File

@@ -15,7 +15,6 @@
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
// CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]]
@@ -32,13 +31,13 @@
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
@@ -46,7 +45,7 @@
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]]
tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
@@ -97,7 +96,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
// CHECK: scf.for
// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
@@ -108,12 +106,12 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
@@ -121,7 +119,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]]
tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
@@ -174,23 +172,22 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]]
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]]
tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,

View File

@@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%a_tileptr_init = tt.make_tensor_ptr %A, [%c64, %c16], [%c16, %c1], [%c0, %c0] { order = array<i32: 1, 0> } : !tt.ptr<tensor<64x16xf16>, 1>
// CHECK: %[[BUFFER:.*]] = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared>
// CHECK: %[[MBAR:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : !tt.ptr<i64, 3>
// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr<i64, 3>, i1
// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr<i64, 3>, i1
// CHECK: %[[INSERT:.*]] = triton_nvidia_gpu.insert_slice_async_v2 %[[TENSOR_PTR]], %[[BUFFER]], %{{.*}}, %[[MBAR]]
// CHECK: %[[EXT:.*]] = triton_gpu.extract_slice %[[INSERT]][0, 0, 0] [1, 64, 16] [1, 1, 1] : tensor<1x64x16xf16, #shared> to tensor<64x16xf16, #shared>
// CHECK: triton_nvidia_gpu.mbarrier_wait %[[MBAR]], %false : <i64, 3>

View File

@@ -1,8 +1,8 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s
// CHECK: offset = 0, size = 49152
// CHECK: offset = 49152, size = 49152
// CHECK: size = 98304
// CHECK: offset = 0, size = 32768
// CHECK: offset = 32768, size = 32768
// CHECK: size = 65536
module {
tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1>

View File

@@ -1,234 +1,172 @@
// RUN: triton-opt -split-input-file -triton-nvidia-gpu-ws-materialization='compute-capability=90' %s | FileCheck %s
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} {
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @simple_gemm
// CHECK: triton_nvidia_gpu.alloc_mbarrier
// CHECK: scf.if
// CHECK: scf.for
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_wait
// CHECK: triton_gpu.insert_slice
// CHECK: triton_gpu.insert_slice
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_arrive
// CHECK: scf.yield
// CHECK: scf.if
// CHECK: triton_nvidia_gpu.bar_wait
// CHECK: scf.for
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_wait
// CHECK: triton_gpu.extract_slice
// CHECK: triton_gpu.extract_slice
// CHECK: tt.dot
// CHECK: triton_nvidia_gpu.dot_async
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_arrive
// CHECK: scf.yield
// CHECK: triton_nvidia_gpu.bar_arrive
// CHECK: triton_nvidia_gpu.bar_wait
// CHECK: triton_nvidia_gpu.dot_wait
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_arrive
// CHECK: tt.store
// CHECK: triton_nvidia_gpu.bar_arrive
tt.func public @simple_gemm(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
%0 = triton_gpu.alloc_tensor : tensor<3x32x128xf16, #shared>
%1 = triton_gpu.alloc_tensor : tensor<3x128x32xf16, #shared1>
tt.func public @simple_gemm(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%0 = triton_gpu.alloc_tensor : tensor<3x128x64xf16, #shared>
%1 = triton_gpu.alloc_tensor : tensor<3x64x128xf16, #shared1>
%2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token>
%3 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex
%4 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex
%5 = triton_nvidia_gpu.get_agent_id : i32
%c0_i32 = arith.constant 0 : i32
%6 = arith.cmpi eq, %5, %c0_i32 : i32
scf.if %6 {
%cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<32x128xi32, #blocked>
%cst_1 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<128x32xi32, #blocked1>
%c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32
%c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32
%c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index
%c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index
%c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32
%c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32
%c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32
%8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%9 = tt.get_program_id y {async_agent = dense<0> : vector<1xi32>} : i32
%10 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%11 = arith.divsi %10, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%12 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%13 = arith.divsi %12, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%14 = arith.muli %13, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%15 = arith.divsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%16 = arith.muli %15, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%17 = arith.subi %11, %16 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%18 = arith.cmpi slt, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%19 = arith.select %18, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%20 = arith.remsi %8, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%21 = arith.addi %16, %20 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%22 = arith.remsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%23 = arith.divsi %22, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%24 = arith.muli %21, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%25 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%26 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%27 = tt.splat %24 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%28 = arith.addi %27, %25 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%29 = arith.muli %23, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%31 = arith.addi %30, %26 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%32 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%33 = arith.remsi %28, %32 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%34 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%35 = arith.remsi %31, %34 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%36 = arith.muli %9, %c32_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%37 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%38 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%39 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%40 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%41 = arith.addi %39, %37 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%42 = arith.addi %40, %38 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%43 = tt.expand_dims %33 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%44 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked1>
%45 = arith.muli %43, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1>
%46 = tt.expand_dims %41 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1>
%47 = tt.broadcast %45 {async_agent = dense<0> : vector<1xi32>} : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
%48 = tt.broadcast %46 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
%49 = arith.addi %47, %48 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32xi32, #blocked1>
%50 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<128x32x!tt.ptr<f16, 1>, #blocked1>
%51 = tt.addptr %50, %49 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%52 = tt.expand_dims %42 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked>
%53 = tt.expand_dims %35 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked>
%54 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x128xi32, #blocked>
%55 = arith.muli %53, %54 {async_agent = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked>
%56 = tt.broadcast %52 {async_agent = dense<0> : vector<1xi32>} : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked>
%57 = tt.broadcast %55 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked>
%58 = arith.addi %56, %57 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128xi32, #blocked>
%59 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<32x128x!tt.ptr<f16, 1>, #blocked>
%60 = tt.addptr %59, %58 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
%61 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%62 = arith.divsi %61, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%63 = arith.index_cast %62 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index
%c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
%64:3 = scf.for %arg9 = %c0 to %63 step %c1 iter_args(%arg10 = %51, %arg11 = %60, %arg12 = %c0_i32_2) -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>, i32) {
triton_nvidia_gpu.producer_acquire %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%65 = triton_gpu.insert_slice %arg10, %1, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1> -> tensor<3x128x32xf16, #shared1>
%66 = triton_gpu.insert_slice %arg11, %0, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128x!tt.ptr<f16, 1>, #blocked> -> tensor<3x32x128xf16, #shared>
%67 = tt.addptr %arg10, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%68 = tt.addptr %arg11, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
%c1_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
%c3_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
%69 = arith.addi %arg12, %c1_i32_3 {async_agent = dense<0> : vector<1xi32>} : i32
%70 = arith.remsi %69, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
triton_nvidia_gpu.producer_commit %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
scf.yield %67, %68, %70 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>, i32
} {async_agent = dense<0> : vector<1xi32>}
}
%c1_i32 = arith.constant 1 : i32
%c1_i32_0 = arith.constant 1 : i32
%7 = arith.cmpi sge, %5, %c1_i32_0 : i32
scf.if %7 {
%cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
%c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32
%c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32
%c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index
%c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index
%c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32
%c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32
%c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32
%8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%9 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%10 = arith.divsi %9, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%11 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%12 = arith.divsi %11, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%13 = arith.muli %12, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%14 = arith.divsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%15 = arith.muli %14, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%16 = arith.subi %10, %15 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%17 = arith.cmpi slt, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%18 = arith.select %17, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%19 = arith.remsi %8, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%20 = arith.addi %15, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%21 = arith.remsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%22 = arith.divsi %21, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%23 = arith.muli %20, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%24 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%25 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%26 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%27 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%28 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%29 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%30 = arith.addi %28, %24 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%31 = arith.addi %29, %26 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%32 = arith.muli %22, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%33 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%34 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%35 = arith.addi %33, %25 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%36 = arith.addi %34, %27 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%37 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%38 = tt.splat %arg4 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%39 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%40 = arith.divsi %39, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%41 = arith.index_cast %40 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index
%c127_i32 = arith.constant 127 : i32
%c1_i64 = arith.constant 1 : i64
%c128_i32 = arith.constant 128 : i32
%c8_i32 = arith.constant 8 : i32
%3 = tt.get_program_id x : i32
%4 = arith.addi %arg6, %c127_i32 : i32
%5 = arith.divsi %4, %c128_i32 : i32
%6 = arith.addi %arg5, %c127_i32 : i32
%7 = arith.divsi %6, %c128_i32 : i32
%8 = arith.muli %5, %c8_i32 : i32
%9 = arith.divsi %3, %8 : i32
%10 = arith.muli %9, %c8_i32 : i32
%11 = arith.subi %7, %10 : i32
%12 = arith.minsi %11, %c8_i32 : i32
%13 = arith.remsi %3, %12 : i32
%14 = arith.addi %10, %13 : i32
%15 = arith.remsi %3, %8 : i32
%16 = arith.divsi %15, %12 : i32
%17 = arith.muli %14, %c128_i32 : i32
%18 = arith.muli %16, %c128_i32 : i32
%19 = arith.extsi %arg5 : i32 to i64
%20 = arith.extsi %arg7 : i32 to i64
%21 = arith.extsi %arg8 : i32 to i64
%22 = tt.make_tensor_ptr %arg0, [%19, %20], [%21, %c1_i64], [%17, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked>, 1>
%23 = arith.extsi %arg6 : i32 to i64
%24 = arith.extsi %arg9 : i32 to i64
%25 = tt.make_tensor_ptr %arg1, [%20, %23], [%c1_i64, %24], [%c0_i32, %18] {order = array<i32: 0, 1>} : <tensor<64x128xf16, #blocked1>, 1>
%26 = arith.extsi %arg11 : i32 to i64
%27 = tt.make_tensor_ptr %arg4, [%19, %23], [%26, %c1_i64], [%17, %18] {order = array<i32: 1, 0>} : <tensor<128x128xf32, #blocked>, 1>
%28 = triton_nvidia_gpu.get_agent_id : i32
%c0_i32_0 = arith.constant 0 : i32
%29 = arith.cmpi eq, %28, %c0_i32_0 : i32
scf.if %29 {
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
triton_nvidia_gpu.lock %3 {mutex.barId = dense<1> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
%42:2 = scf.for %arg9 = %c0 to %41 step %c1 iter_args(%arg10 = %cst, %arg11 = %c0_i32_1) -> (tensor<128x128xf32, #mma>, i32) {
triton_nvidia_gpu.consumer_wait %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%62 = triton_gpu.extract_slice %1[%arg11, 0, 0] [1, 128, 32] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x128x32xf16, #shared1> to tensor<128x32xf16, #shared1>
%63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared>
%64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1>
%65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared>
%66 = tt.dot %64, %65, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
%c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32
%67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
%68 = arith.remsi %67, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
triton_nvidia_gpu.consumer_release %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
scf.yield %66, %68 : tensor<128x128xf32, #mma>, i32
%false = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} false
%31:4 = scf.for %arg12 = %c0_i32 to %arg7 step %c64_i32 iter_args(%arg13 = %22, %arg14 = %25, %arg15 = %false, %arg16 = %c0_i32_1) -> (!tt.ptr<tensor<128x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>, i1, i32) : i32 {
triton_nvidia_gpu.producer_acquire %2, %arg16 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%32 = triton_gpu.insert_slice %arg13, %0, %arg16 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<128x64xf16, #blocked>, 1> -> tensor<3x128x64xf16, #shared>
%33 = triton_gpu.insert_slice %arg14, %1, %arg16 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16, #blocked1>, 1> -> tensor<3x64x128xf16, #shared1>
triton_nvidia_gpu.producer_commit %2, %arg16 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%34 = tt.advance %arg13, [%c0_i32, %c64_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<128x64xf16, #blocked>, 1>
%35 = tt.advance %arg14, [%c64_i32, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x128xf16, #blocked1>, 1>
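    // The rest of the loop body advances the circular buffer index for the
    // 3-deep pipeline: the index wraps from 3 back to 0, and the i1 flag
    // toggles on every wrap (it likely serves as the mbarrier phase once
    // this IR is lowered).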
%c1_i32_2 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
%c0_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
%true = arith.constant {async_agent = dense<0> : vector<1xi32>} true
%36 = arith.addi %arg16, %c1_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
%37 = arith.cmpi uge, %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%38 = arith.cmpi ult, %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%39 = arith.subi %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%40 = arith.select %37, %39, %36 {async_agent = dense<0> : vector<1xi32>} : i32
%41 = arith.xori %arg15, %true {async_agent = dense<0> : vector<1xi32>} : i1
%42 = arith.andi %37, %41 {async_agent = dense<0> : vector<1xi32>} : i1
%43 = arith.andi %38, %arg15 {async_agent = dense<0> : vector<1xi32>} : i1
%44 = arith.ori %42, %43 {async_agent = dense<0> : vector<1xi32>} : i1
scf.yield {async_agent = dense<0> : vector<1xi32>} %34, %35, %44, %40 : !tt.ptr<tensor<128x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>, i1, i32
} {async_agent = dense<0> : vector<1xi32>}
} {async_agent = dense<0> : vector<1xi32>}
%c1_i32 = arith.constant 1 : i32
%30 = arith.cmpi eq, %28, %c1_i32 : i32
scf.if %30 {
%cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
%false = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} false
%31:3 = scf.for %arg12 = %c0_i32 to %arg7 step %c64_i32 iter_args(%arg13 = %cst, %arg14 = %false, %arg15 = %c0_i32_1) -> (tensor<128x128xf32, #mma>, i1, i32) : i32 {
triton_nvidia_gpu.consumer_wait %2, %arg15 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%37 = triton_gpu.extract_slice %0[%arg15, 0, 0] [1, 128, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x128x64xf16, #shared> to tensor<128x64xf16, #shared>
%38 = triton_gpu.convert_layout %37 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #shared>
%39 = triton_gpu.extract_slice %1[%arg15, 0, 0] [1, 64, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x64x128xf16, #shared1> to tensor<64x128xf16, #shared1>
%40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #shared1>
%41 = triton_nvidia_gpu.dot_async %38, %40, %arg13 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%42 = arith.cmpi sgt, %arg12, %c0_i32 {async_agent = dense<1> : vector<1xi32>} : i32
scf.if %42 {
%c0_i32_6 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
%c1_i32_7 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%c2_i32_8 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
%52 = arith.subi %arg15, %c1_i32_7 {async_agent = dense<1> : vector<1xi32>} : i32
%53 = arith.cmpi eq, %arg15, %c0_i32_6 {async_agent = dense<1> : vector<1xi32>} : i32
%54 = arith.select %53, %c2_i32_8, %52 {async_agent = dense<1> : vector<1xi32>} : i32
triton_nvidia_gpu.consumer_release %2, %54 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
} {async_agent = dense<1> : vector<1xi32>}
%c1_i32_4 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%c0_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
%true = arith.constant {async_agent = dense<1> : vector<1xi32>} true
%43 = arith.addi %arg15, %c1_i32_4 {async_agent = dense<1> : vector<1xi32>} : i32
%44 = arith.cmpi uge, %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
%45 = arith.cmpi ult, %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
%46 = arith.subi %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
%47 = arith.select %44, %46, %43 {async_agent = dense<1> : vector<1xi32>} : i32
%48 = arith.xori %arg14, %true {async_agent = dense<1> : vector<1xi32>} : i1
%49 = arith.andi %44, %48 {async_agent = dense<1> : vector<1xi32>} : i1
%50 = arith.andi %45, %arg14 {async_agent = dense<1> : vector<1xi32>} : i1
%51 = arith.ori %49, %50 {async_agent = dense<1> : vector<1xi32>} : i1
scf.yield {async_agent = dense<1> : vector<1xi32>} %41, %51, %47 : tensor<128x128xf32, #mma>, i1, i32
} {async_agent = dense<1> : vector<1xi32>}
triton_nvidia_gpu.unlock %3 {mutex.barId = dense<2> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
triton_nvidia_gpu.lock %4 {mutex.barId = dense<3> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
%43 = arith.truncf %42#0 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%44 = tt.expand_dims %30 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
%45 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked2>
%46 = arith.muli %44, %45 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked2>
%47 = tt.expand_dims %35 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2>
%48 = tt.broadcast %46 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
%49 = tt.broadcast %47 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
%50 = arith.addi %48, %49 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2>
%51 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked2>
%52 = tt.addptr %51, %50 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr<f16, 1>, #blocked2>, tensor<128x128xi32, #blocked2>
%53 = "triton_gpu.cmpi"(%31, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%54 = tt.expand_dims %53 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2>
%55 = "triton_gpu.cmpi"(%36, %38) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%56 = tt.expand_dims %55 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2>
%57 = tt.broadcast %54 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2>
%58 = tt.broadcast %56 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2>
%59 = arith.andi %57, %58 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked2>
%60 = triton_gpu.convert_layout %43 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2>
tt.store %52, %60, %59 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2>
triton_nvidia_gpu.unlock %4 {mutex.barId = dense<4> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
}
%32 = triton_nvidia_gpu.dot_wait %31#0 {async_agent = dense<1> : vector<1xi32>, pendings = 0 : i32} : tensor<128x128xf32, #mma>
%c0_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
%c1_i32_3 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%c2_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
%33 = arith.subi %31#2, %c1_i32_3 {async_agent = dense<1> : vector<1xi32>} : i32
%34 = arith.cmpi eq, %31#2, %c0_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
%35 = arith.select %34, %c2_i32, %33 {async_agent = dense<1> : vector<1xi32>} : i32
triton_nvidia_gpu.consumer_release %2, %35 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%36 = triton_gpu.convert_layout %32 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf32, #mma>) -> tensor<128x128xf32, #blocked2>
tt.store %27, %36 {async_agent = dense<1> : vector<1xi32>, boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<128x128xf32, #blocked>, 1>, tensor<128x128xf32, #blocked2>
} {async_agent = dense<1> : vector<1xi32>}
tt.return
}
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @matmal_from_wsmutex
// CHECK: triton_nvidia_gpu.alloc_mbarrier
// CHECK: scf.if
// CHECK: scf.for
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_wait
// CHECK: triton_gpu.insert_slice
// CHECK: triton_gpu.insert_slice
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_arrive
// CHECK: scf.yield
@@ -239,174 +177,224 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
// CHECK: triton_nvidia_gpu.mbarrier_wait
// CHECK: triton_gpu.extract_slice
// CHECK: triton_gpu.extract_slice
// CHECK: tt.dot
// CHECK: triton_nvidia_gpu.dot_async
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_arrive
// CHECK: scf.yield
// CHECK: triton_nvidia_gpu.bar_arrive
// CHECK: triton_nvidia_gpu.dot_wait
// CHECK: triton_nvidia_gpu.extract_mbarrier
// CHECK: triton_nvidia_gpu.mbarrier_arrive
// CHECK: triton_nvidia_gpu.bar_wait
// CHECK: tt.store
// CHECK: triton_nvidia_gpu.bar_arrive
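// As in the previous test, these CHECK lines sketch the expected lowering of
// the mutex-style warp-specialized persistent kernel below: get_agent_id
// selects the producer or consumer role, the depth-3 token ring becomes
// mbarrier wait/arrive pairs, and the wsmutex lock/unlock regions appear in
// the checked output as named-barrier bar_wait/bar_arrive synchronization
// around the dot_wait and tt.store epilogue.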
tt.func public @matmal_from_wsmutex(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) {
tt.func public @matmal_from_wsmutex(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%0 = triton_gpu.alloc_tensor : tensor<3x64x16xf16, #shared>
%1 = triton_gpu.alloc_tensor : tensor<3x16x64xf16, #shared1>
%2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token>
%3 = triton_nvidia_gpu.get_agent_id : i32
%c63_i32 = arith.constant 63 : i32
%c0_i32 = arith.constant 0 : i32
%4 = arith.cmpi eq, %3, %c0_i32 : i32
scf.if %4 {
%cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<16x64xi32, #blocked>
%cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<64x16xi32, #blocked1>
%c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32
%c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32
%c1_i64 = arith.constant 1 : i64
%c64_i32 = arith.constant 64 : i32
%c8_i32 = arith.constant 8 : i32
%3 = tt.get_program_id x : i32
%4 = arith.addi %arg6, %c63_i32 : i32
%5 = arith.divsi %4, %c64_i32 : i32
%6 = arith.addi %arg5, %c63_i32 : i32
%7 = arith.divsi %6, %c64_i32 : i32
%8 = arith.muli %5, %c8_i32 : i32
%9 = arith.divsi %3, %8 : i32
%10 = arith.muli %9, %c8_i32 : i32
%11 = arith.subi %7, %10 : i32
%12 = arith.minsi %11, %c8_i32 : i32
%13 = arith.remsi %3, %8 : i32
%14 = arith.remsi %13, %12 : i32
%15 = arith.addi %10, %14 : i32
%16 = arith.divsi %13, %12 : i32
%17 = arith.muli %15, %c64_i32 : i32
%18 = arith.muli %16, %c64_i32 : i32
%19 = arith.extsi %arg5 : i32 to i64
%20 = arith.extsi %arg7 : i32 to i64
%21 = arith.extsi %arg8 : i32 to i64
%22 = tt.make_tensor_ptr %arg0, [%19, %20], [%21, %c1_i64], [%17, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x16xf16, #blocked>, 1>
%23 = arith.extsi %arg6 : i32 to i64
%24 = arith.extsi %arg9 : i32 to i64
%25 = tt.make_tensor_ptr %arg1, [%20, %23], [%c1_i64, %24], [%c0_i32, %18] {order = array<i32: 0, 1>} : <tensor<16x64xf16, #blocked1>, 1>
%26 = arith.extsi %arg10 : i32 to i64
%27 = tt.make_tensor_ptr %arg4, [%19, %23], [%26, %c1_i64], [%17, %18] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #blocked>, 1>
%28 = triton_nvidia_gpu.get_agent_id : i32
%c0_i32_0 = arith.constant 0 : i32
%29 = arith.cmpi eq, %28, %c0_i32_0 : i32
scf.if %29 {
%c132_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 132 : i32
%c15_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 15 : i32
%c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
%6 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32
%7 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%8 = arith.divsi %7, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%9 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%10 = arith.divsi %9, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%11 = arith.muli %8, %10 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%12 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%13 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%14 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%15 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%16 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked1>
%17 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%18 = tt.expand_dims %17 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x16xi32, #blocked1>
%19 = tt.broadcast %18 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x16xi32, #blocked1>) -> tensor<64x16xi32, #blocked1>
%20 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<64x16x!tt.ptr<f16, 1>, #blocked1>
%21 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%22 = tt.expand_dims %21 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16x1xi32, #blocked>
%23 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x64xi32, #blocked>
%24 = tt.broadcast %22 {async_agent = dense<0> : vector<1xi32>} : (tensor<16x1xi32, #blocked>) -> tensor<16x64xi32, #blocked>
%25 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<16x64x!tt.ptr<f16, 1>, #blocked>
%31 = arith.muli %7, %5 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%32 = arith.addi %arg7, %c15_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%33 = arith.divsi %32, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%34 = arith.subi %c0_i32, %33 {async_agent = dense<0> : vector<1xi32>} : i32
%35 = arith.muli %34, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
%c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
%26 = scf.for %arg9 = %6 to %11 step %c114_i32 iter_args(%arg10 = %c0_i32_2) -> (i32) : i32 {
%27 = arith.divsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32
%28 = arith.remsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32
%29 = arith.muli %27, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%31 = arith.addi %30, %12 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%32 = arith.remsi %31, %14 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%33 = arith.muli %28, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%34 = tt.splat %33 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%35 = arith.addi %34, %13 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%36 = arith.remsi %35, %15 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%37 = tt.expand_dims %32 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%38 = arith.muli %37, %16 {async_agent = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1>
%39 = tt.broadcast %38 {async_agent = dense<0> : vector<1xi32>} : (tensor<64x1xi32, #blocked1>) -> tensor<64x16xi32, #blocked1>
%40 = arith.addi %39, %19 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16xi32, #blocked1>
%41 = tt.addptr %20, %40 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<64x16xi32, #blocked1>
%42 = tt.expand_dims %36 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi32, #blocked>
%43 = arith.muli %42, %23 {async_agent = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked>
%44 = tt.broadcast %43 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x64xi32, #blocked>) -> tensor<16x64xi32, #blocked>
%45 = arith.addi %24, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64xi32, #blocked>
%46 = tt.addptr %25, %45 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr<f16, 1>, #blocked>, tensor<16x64xi32, #blocked>
%c3_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
%47 = arith.subi %arg5, %c0_i32_1 {async_agent = dense<0> : vector<1xi32>} : i32
%48 = arith.divui %47, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%49 = arith.muli %arg10, %48 {async_agent = dense<0> : vector<1xi32>} : i32
%c3_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
%50:3 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %41, %arg13 = %46, %arg14 = %49) -> (tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<16x64x!tt.ptr<f16, 1>, #blocked>, i32) : i32 {
%52 = arith.remsi %arg14, %c3_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
triton_nvidia_gpu.producer_acquire %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%53 = triton_gpu.insert_slice %arg12, %0, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16x!tt.ptr<f16, 1>, #blocked1> -> tensor<3x64x16xf16, #shared>
%54 = triton_gpu.insert_slice %arg13, %1, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16, 1>, #blocked> -> tensor<3x16x64xf16, #shared1>
triton_nvidia_gpu.producer_commit %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%55 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<64x16xi32, #blocked1>
%56 = tt.addptr %arg13, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr<f16, 1>, #blocked>, tensor<16x64xi32, #blocked>
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
%36:5 = scf.for %arg11 = %3 to %31 step %c132_i32 iter_args(%arg12 = %22, %arg13 = %25, %arg14 = %15, %arg15 = %16, %arg16 = %c0_i32_1) -> (!tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i32, i32, i32) : i32 {
%37 = arith.divsi %arg11, %8 {async_agent = dense<0> : vector<1xi32>} : i32
%38 = arith.muli %37, %c8_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%39 = arith.subi %7, %38 {async_agent = dense<0> : vector<1xi32>} : i32
%40 = arith.minsi %39, %c8_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%41 = arith.remsi %arg11, %8 {async_agent = dense<0> : vector<1xi32>} : i32
%42 = arith.remsi %41, %40 {async_agent = dense<0> : vector<1xi32>} : i32
%43 = arith.addi %38, %42 {async_agent = dense<0> : vector<1xi32>} : i32
%44 = arith.divsi %41, %40 {async_agent = dense<0> : vector<1xi32>} : i32
%45 = arith.subi %43, %arg14 {async_agent = dense<0> : vector<1xi32>} : i32
%46 = arith.muli %45, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%47 = tt.advance %arg12, [%46, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x16xf16, #blocked>, 1>
%48 = arith.subi %44, %arg15 {async_agent = dense<0> : vector<1xi32>} : i32
%49 = arith.muli %48, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%50 = tt.advance %arg13, [%c0_i32, %49] {async_agent = dense<0> : vector<1xi32>} : <tensor<16x64xf16, #blocked1>, 1>
%c3_i32_2 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
%c0_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
%51 = arith.subi %arg7, %c0_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%52 = arith.addi %51, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%c1_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
%c2_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 2 : i32
%53 = arith.subi %52, %c1_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
%54 = arith.divui %53, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
%55 = arith.muli %arg16, %54 {async_agent = dense<0> : vector<1xi32>} : i32
%56 = arith.divui %55, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
%57 = arith.muli %56, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
%58 = arith.subi %55, %57 {async_agent = dense<0> : vector<1xi32>} : i32
%59 = arith.andi %56, %c1_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
%60 = arith.trunci %59 {async_agent = dense<0> : vector<1xi32>} : i32 to i1
%61:4 = scf.for %arg17 = %c0_i32 to %arg7 step %c16_i32 iter_args(%arg18 = %47, %arg19 = %50, %arg20 = %60, %arg21 = %58) -> (!tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i1, i32) : i32 {
triton_nvidia_gpu.producer_acquire %2, %arg21 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%65 = triton_gpu.insert_slice %arg18, %0, %arg21 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x16xf16, #blocked>, 1> -> tensor<3x64x16xf16, #shared>
%66 = triton_gpu.insert_slice %arg19, %1, %arg21 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<3x16x64xf16, #shared1>
triton_nvidia_gpu.producer_commit %2, %arg21 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%67 = tt.advance %arg18, [%c0_i32, %c16_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x16xf16, #blocked>, 1>
%68 = tt.advance %arg19, [%c16_i32, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<16x64xf16, #blocked1>, 1>
%c1_i32_6 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
%57 = arith.addi %arg14, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32
scf.yield {async_agent = dense<0> : vector<1xi32>} %55, %56, %57 : tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<16x64x!tt.ptr<f16, 1>, #blocked>, i32
%c0_i32_7 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
%true = arith.constant {async_agent = dense<0> : vector<1xi32>} true
%69 = arith.addi %arg21, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32
%70 = arith.cmpi uge, %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
%71 = arith.cmpi ult, %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
%72 = arith.subi %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
%73 = arith.select %70, %72, %69 {async_agent = dense<0> : vector<1xi32>} : i32
%74 = arith.xori %arg20, %true {async_agent = dense<0> : vector<1xi32>} : i1
%75 = arith.andi %70, %74 {async_agent = dense<0> : vector<1xi32>} : i1
%76 = arith.andi %71, %arg20 {async_agent = dense<0> : vector<1xi32>} : i1
%77 = arith.ori %75, %76 {async_agent = dense<0> : vector<1xi32>} : i1
scf.yield {async_agent = dense<0> : vector<1xi32>} %67, %68, %77, %73 : !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i1, i32
} {async_agent = dense<0> : vector<1xi32>}
%62 = tt.advance %61#0, [%c0_i32, %35] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x16xf16, #blocked>, 1>
%63 = tt.advance %61#1, [%35, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<16x64xf16, #blocked1>, 1>
%c1_i32_5 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
%51 = arith.addi %arg10, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32
scf.yield {async_agent = dense<0> : vector<1xi32>} %51 : i32
%64 = arith.addi %arg16, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32
scf.yield {async_agent = dense<0> : vector<1xi32>} %62, %63, %43, %44, %64 : !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i32, i32, i32
} {async_agent = dense<0> : vector<1xi32>}
} {async_agent = dense<0> : vector<1xi32>}
%c1_i32 = arith.constant 1 : i32
%5 = arith.cmpi eq, %3, %c1_i32 : i32
scf.if %5 {
%c0_i32_0 = arith.constant 0 : i32
%6 = triton_nvidia_gpu.get_mutex_role_id {async_agent = dense<1> : vector<1xi32>, num = 2 : i32} : i32
%7 = arith.cmpi ne, %6, %c0_i32_0 : i32
%8 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
%9 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
%30 = arith.cmpi eq, %28, %c1_i32 : i32
scf.if %30 {
%c0_i32_1 = arith.constant 0 : i32
%31 = triton_nvidia_gpu.get_mutex_role_id {async_agent = dense<1> : vector<1xi32>, num = 2 : i32} : i32
%32 = arith.cmpi ne, %31, %c0_i32_1 : i32
%33 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
%34 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
%cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf32, #mma>
%c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32
%c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32
%c132_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 132 : i32
%c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
%10 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32
%11 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%12 = arith.divsi %11, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%13 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%14 = arith.divsi %13, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%15 = arith.muli %12, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%16 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%17 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%18 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked2>
%19 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr<f32, 1>) -> tensor<64x1x!tt.ptr<f32, 1>, #blocked2>
%35 = arith.muli %7, %5 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
%c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
%20 = arith.muli %c114_i32, %6 {async_agent = dense<1> : vector<1xi32>} : i32
%21 = arith.addi %10, %20 {async_agent = dense<1> : vector<1xi32>} : i32
%36 = arith.muli %c132_i32, %31 {async_agent = dense<1> : vector<1xi32>} : i32
%37 = arith.addi %3, %36 {async_agent = dense<1> : vector<1xi32>} : i32
%c2_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
%22 = arith.muli %c114_i32, %c2_i32 {async_agent = dense<1> : vector<1xi32>} : i32
%23 = arith.addi %c0_i32_2, %6 {async_agent = dense<1> : vector<1xi32>} : i32
%24 = scf.for %arg9 = %21 to %15 step %22 iter_args(%arg10 = %23) -> (i32) : i32 {
%25 = arith.cmpi ne, %arg9, %10 : i32
%26 = arith.ori %25, %7 {agent.mutex_role = 0 : i32} : i1
scf.if %26 {
triton_nvidia_gpu.lock %8 {agent.mutex_role = 0 : i32} : !triton_nvidia_gpu.mutex
%38 = arith.muli %c132_i32, %c2_i32 {async_agent = dense<1> : vector<1xi32>} : i32
%39 = arith.addi %c0_i32_2, %31 {async_agent = dense<1> : vector<1xi32>} : i32
%40:4 = scf.for %arg11 = %37 to %35 step %38 iter_args(%arg12 = %27, %arg13 = %15, %arg14 = %16, %arg15 = %39) -> (!tt.ptr<tensor<64x64xf16, #blocked>, 1>, i32, i32, i32) : i32 {
%41 = arith.cmpi ne, %arg11, %3 : i32
%42 = arith.ori %41, %32 : i1
scf.if %42 {
triton_nvidia_gpu.lock %33 {agent.mutex_role = 0 : i32} : !triton_nvidia_gpu.mutex
} {agent.mutex_role = 0 : i32}
%27 = arith.divsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%28 = arith.remsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%29 = arith.muli %27, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%30 = tt.splat %29 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%31 = arith.addi %30, %17 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%32 = arith.muli %28, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%33 = tt.splat %32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%34 = arith.addi %33, %16 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%43 = arith.divsi %arg11, %8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%44 = arith.muli %43, %c8_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%45 = arith.subi %7, %44 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%46 = arith.minsi %45, %c8_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%47 = arith.remsi %arg11, %8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%48 = arith.remsi %47, %46 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%49 = arith.addi %44, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%50 = arith.divsi %47, %46 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%51 = arith.subi %49, %arg13 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%52 = arith.muli %51, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%53 = arith.subi %50, %arg14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%54 = arith.muli %53, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%c3_i32_3 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32
%35 = arith.subi %arg5, %c0_i32_1 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%36 = arith.divui %35, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%37 = arith.muli %arg10, %36 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%c3_i32_4 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32
%38:2 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %cst, %arg13 = %37) -> (tensor<64x64xf32, #mma>, i32) : i32 {
%48 = arith.remsi %arg13, %c3_i32_4 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
triton_nvidia_gpu.consumer_wait %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%49 = triton_gpu.extract_slice %0[%48, 0, 0] [1, 64, 16] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared>
%50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
%51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
%53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
scf.yield {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} %53, %54 : tensor<64x64xf32, #mma>, i32
} {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>}
triton_nvidia_gpu.unlock %8 : !triton_nvidia_gpu.mutex
scf.if %26 {
triton_nvidia_gpu.lock %9 {agent.mutex_role = 1 : i32} : !triton_nvidia_gpu.mutex
%c0_i32_4 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
%55 = arith.subi %arg7, %c0_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%56 = arith.addi %55, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%c1_i32_5 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%c2_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
%57 = arith.subi %56, %c1_i32_5 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%58 = arith.divui %57, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%59 = arith.muli %arg15, %58 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%60 = arith.divui %59, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%61 = arith.muli %60, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%62 = arith.subi %59, %61 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%63 = arith.andi %60, %c1_i32_5 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%64 = arith.trunci %63 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 to i1
%65:3 = scf.for %arg16 = %c0_i32 to %arg7 step %c16_i32 iter_args(%arg17 = %cst, %arg18 = %64, %arg19 = %62) -> (tensor<64x64xf32, #mma>, i1, i32) : i32 {
triton_nvidia_gpu.consumer_wait %2, %arg19 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%74 = triton_gpu.extract_slice %0[%arg19, 0, 0] [1, 64, 16] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared>
%75 = triton_gpu.convert_layout %74 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
%76 = triton_gpu.extract_slice %1[%arg19, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%77 = triton_gpu.convert_layout %76 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
%78 = triton_nvidia_gpu.dot_async %75, %77, %arg17 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%79 = arith.cmpi sgt, %arg16, %c0_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
scf.if %79 {
%c0_i32_13 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
%c1_i32_14 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%c2_i32_15 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
%89 = arith.subi %arg19, %c1_i32_14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%90 = arith.cmpi eq, %arg19, %c0_i32_13 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%91 = arith.select %90, %c2_i32_15, %89 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
triton_nvidia_gpu.consumer_release %2, %91 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
} {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>}
%c1_i32_11 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%c0_i32_12 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
%true = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} true
%80 = arith.addi %arg19, %c1_i32_11 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%81 = arith.cmpi uge, %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%82 = arith.cmpi ult, %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%83 = arith.subi %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%84 = arith.select %81, %83, %80 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%85 = arith.xori %arg18, %true {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
%86 = arith.andi %81, %85 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
%87 = arith.andi %82, %arg18 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
%88 = arith.ori %86, %87 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
scf.yield {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} %78, %88, %84 : tensor<64x64xf32, #mma>, i1, i32
} {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>}
triton_nvidia_gpu.unlock %33 : !triton_nvidia_gpu.mutex
%66 = triton_nvidia_gpu.dot_wait %65#0 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>, pendings = 0 : i32} : tensor<64x64xf32, #mma>
%c0_i32_7 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
%c1_i32_8 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%c2_i32_9 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
%67 = arith.subi %65#2, %c1_i32_8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%68 = arith.cmpi eq, %65#2, %c0_i32_7 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
%69 = arith.select %68, %c2_i32_9, %67 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
triton_nvidia_gpu.consumer_release %2, %69 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
scf.if %42 {
triton_nvidia_gpu.lock %34 {agent.mutex_role = 1 : i32} : !triton_nvidia_gpu.mutex
} {agent.mutex_role = 1 : i32}
%39 = tt.expand_dims %31 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2>
%40 = arith.muli %39, %18 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked2>
%41 = tt.addptr %19, %40 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr<f32, 1>, #blocked2>, tensor<64x1xi32, #blocked2>
%42 = tt.expand_dims %34 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%43 = tt.broadcast %41 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x1x!tt.ptr<f32, 1>, #blocked2>) -> tensor<64x64x!tt.ptr<f32, 1>, #blocked2>
%44 = tt.broadcast %42 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%45 = tt.addptr %43, %44 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x64x!tt.ptr<f32, 1>, #blocked2>, tensor<64x64xi32, #blocked2>
%46 = triton_gpu.convert_layout %38#0 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked2>
tt.store %45, %46 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked2>
triton_nvidia_gpu.unlock %9 : !triton_nvidia_gpu.mutex
%c1_i32_5 = arith.constant {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%47 = arith.addi %arg10, %c2_i32 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : i32
scf.yield {async_agent = dense<1> : vector<1xi32>} %47 : i32
%70 = arith.truncf %66 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma>
%71 = tt.advance %arg12, [%52, %54] {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : <tensor<64x64xf16, #blocked>, 1>
%72 = triton_gpu.convert_layout %70 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf16, #mma>) -> tensor<64x64xf16, #blocked2>
tt.store %71, %72 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<64x64xf16, #blocked2>
triton_nvidia_gpu.unlock %34 : !triton_nvidia_gpu.mutex
%c1_i32_10 = arith.constant {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%73 = arith.addi %arg15, %c2_i32 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : i32
scf.yield {async_agent = dense<1> : vector<1xi32>} %71, %49, %50, %73 : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, i32, i32, i32
} {async_agent = dense<1> : vector<1xi32>}
} {"agent.num-roles" = 2 : i32, async_agent = dense<1> : vector<1xi32>}
tt.return
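
The bookkeeping around %80..%88 and %67..%69 / %89..%91 above amounts to a three-stage circular buffer with a phase bit: the consume index advances modulo 3, the phase flips whenever the index wraps, and the buffer released after a wait is the one just behind the current index. A loose C++ sketch of that logic, for orientation only (all names here are illustrative, not taken from the IR):

// A loose C++ reading of the multibuffer bookkeeping in the IR above
// (three pipeline stages); names are illustrative, not taken from the IR.
#include <cstdio>

constexpr int kNumStages = 3;

struct PipelineState {
  int stage = 0;      // which of the three staging buffers to consume next
  bool phase = false; // flips whenever the index wraps (barrier parity bit)
};

// Mirrors %80..%88: advance to the next buffer and flip the phase on wrap.
PipelineState advance(PipelineState s) {
  int next = s.stage + 1;
  bool wrapped = next >= kNumStages;
  return {wrapped ? next - kNumStages : next, wrapped ? !s.phase : s.phase};
}

// Mirrors %89..%91 and %67..%69: the stage released after a wait is the one
// just behind the current consume index, modulo the stage count.
int previousStage(int stage) { return stage == 0 ? kNumStages - 1 : stage - 1; }

int main() {
  PipelineState s;
  for (int i = 0; i < 7; ++i) {
    std::printf("consume stage %d (phase %d), release stage %d\n", s.stage,
                (int)s.phase, previousStage(s.stage));
    s = advance(s);
  }
  return 0;
}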

View File

@@ -2,9 +2,9 @@
// CHECK: scf.if
// CHECK: triton_nvidia_gpu.create_mutex
// CHECK: triton_nvidia_gpu.create_mutex
// CHECK: scf.for
// CHECK: triton_nvidia_gpu.create_mutex
// CHECK: triton_nvidia_gpu.create_mutex
// CHECK: triton_nvidia_gpu.lock
// CHECK: agent.mutex_role = 0 : i32
// CHECK: triton_nvidia_gpu.unlock

View File

@@ -22,9 +22,10 @@
// CHECK: triton_gpu.extract_slice
// CHECK: triton_gpu.extract_slice
// CHECK: triton_nvidia_gpu.dot_async
// CHECK: triton_nvidia_gpu.dot_wait
// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 1
// CHECK: triton_nvidia_gpu.consumer_release
// CHECK: scf.yield
// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 0
// CHECK: async_agent = dense<1> : vector<1xi32>
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
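
The updated CHECK lines pin the `pendings` attribute on triton_nvidia_gpu.dot_wait: inside the loop the wait only requires that at most one async dot group remain in flight (pendings = 1), after which the corresponding buffer is handed back via consumer_release, while the wait after the loop drains everything (pendings = 0). A rough C++ sketch of that discipline, under my reading of the semantics (the types and the printf stand-in for consumer_release are illustrative, not repository APIs):

// Sketch of the pendings-based wait discipline (my reading, not a repository
// API): dot_async issues a group, dot_wait(pendings = N) blocks until at most
// N groups remain in flight, and the buffer feeding each retired group is
// released; the release is folded into the wait here for brevity.
#include <cstdio>
#include <deque>

struct DotGroup { int buffer; }; // which of the three staging buffers it reads

static std::deque<DotGroup> inFlight; // issued but not yet retired, oldest first

void dotWait(int pendings) {
  while ((int)inFlight.size() > pendings) {
    DotGroup done = inFlight.front(); // oldest group retires first
    inFlight.pop_front();
    std::printf("consumer_release buffer %d\n", done.buffer); // stand-in
  }
}

int main() {
  for (int i = 0; i < 6; ++i) {
    inFlight.push_back({i % 3}); // dot_async reading buffer i mod 3
    dotWait(1);                  // in-loop wait: keep at most one group pending
  }
  dotWait(0);                    // after the loop: drain everything
  return 0;
}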

View File

@@ -20,7 +20,7 @@ struct TestAliasPass
return opName;
}
static void print(StringRef name, SmallVector<std::string, 4> &vals,
static void print(StringRef name, SmallVector<std::string> &vals,
raw_ostream &os) {
if (vals.empty())
return;
@@ -57,7 +57,7 @@ struct TestAliasPass
auto getAllocOpNames = [&](Value value) {
dataflow::Lattice<AliasInfo> *latticeElement =
analysis->getLatticeElement(value);
SmallVector<std::string, 4> opNames;
SmallVector<std::string> opNames;
if (latticeElement) {
auto &info = latticeElement->getValue();
for (auto &alias : info.getAllocs()) {