Mirror of https://github.com/ROCm/ROCm.git (synced 2026-02-21 03:00:39 -05:00)
Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108

Conflicts:
    bin/triton-translate.cpp
    lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
    lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
    python/triton/compiler/compiler.py
    python/triton/runtime/jit.py
    python/tutorials/06-fused-attention.py
    test/Conversion/tritongpu_to_llvm.mlir
@@ -239,6 +239,13 @@ set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

# TODO: Figure out which target is sufficient to fix errors; triton is
# apparently not enough. Currently set linking libstdc++fs for all targets
# to support some old version GCC compilers like 8.3.0.
if (NOT WIN32 AND NOT APPLE)
link_libraries(stdc++fs)
endif()

if(TRITON_BUILD_PYTHON_MODULE)
add_library(triton SHARED ${PYTHON_SRC})
set(TRITON_LIBRARIES
@@ -259,7 +266,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
@@ -278,9 +284,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton ${ROCM_LIBRARIES} ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
)
# TODO: Figure out which target is sufficient to fix errors; triton is
# apparently not enough
link_libraries(stdc++fs)
endif()

target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
@@ -90,10 +90,10 @@ arbitrary LLVM version.
1. Find the version of LLVM that Triton builds against. Check `python/setup.py`
for a line like

version = "llvm-17.0.0-c5dede880d17"
version = "llvmorg-18-init-7000-g76ce4736721a"

This means that the version of Triton you have builds against
[LLVM](https://github.com/llvm/llvm-project) c5dede880d17.
[LLVM](https://github.com/llvm/llvm-project) 76ce4736721a.

2. `git checkout` LLVM at this revision. Optionally, make additional
modifications to LLVM.
@@ -72,7 +72,6 @@ llvm_update_compile_flags(triton-translate)
MLIRPass
MLIRSupport
MLIRTransforms
MLIRExecutionEngine
MLIRMathToLLVM
MLIRTransformUtils
MLIRLLVMToLLVMIRTranslation

@@ -104,10 +104,15 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
"", llvm::cl::desc("AMDGCN features. e.g. '+sramecc,-xnack'"),
llvm::cl::value_desc("features"), llvm::cl::init("+sramecc,-xnack"));

static llvm::cl::opt<bool> enableFpFusion(
"enable-fp-fusion", llvm::cl::desc("Enables fusion of fadd/fmul"),
llvm::cl::init(true));

llvm::InitLLVM y(argc, argv);

registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);

mlir::MLIRContext context;
@@ -142,12 +147,17 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::outs() << *llvmir << '\n';
} else if (targetKind == "ptx") {
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
<<<<<<< HEAD
ptxVersion.getValue());
} else if (targetKind == "hsaco") {
auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO(
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
GCNFeatures.getValue());
llvm::outs() << hsaco;
=======
ptxVersion.getValue(),
enableFpFusion.getValue());
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
} else {
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
return failure();
@@ -63,8 +63,6 @@ Memory Ops

load
store
atomic_cas
atomic_xchg


Indexing Ops
@@ -129,10 +127,11 @@ Atomic Ops
:toctree: generated
:nosignatures:

atomic_cas
atomic_add
atomic_cas
atomic_max
atomic_min
atomic_xchg


Comparison ops
@@ -63,11 +63,12 @@ private:
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
class SharedMemoryAliasAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
: public dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>> {
public:
using dataflow::SparseDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseForwardDataFlowAnalysis;
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::getLatticeElement;

/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.

@@ -271,8 +271,8 @@ private:
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
};

class AxisInfoAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>> {
private:
AxisInfoVisitorList visitors;

@@ -284,7 +284,7 @@ private:

public:
AxisInfoAnalysis(DataFlowSolver &solver);
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>>::getLatticeElement;
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;


@@ -54,6 +54,8 @@ public:

SmallVector<unsigned> getScratchConfig();

SmallVector<unsigned> getOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

bool isSupportedLayout();
@@ -133,9 +135,9 @@ bool supportMMA(Value value, int version);

bool isSingleValue(Value value);

bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
@@ -27,6 +27,7 @@ include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType

def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
@@ -44,9 +45,13 @@ def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
let assemblyFormat = "attr-dict";
}

def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", []> {
let arguments = (ins I32Attr:$pendings);
def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group",
[DeclareOpInterfaceMethods<InferTypeOpInterface>,
AllTypesMatch<["input", "output"]>]> {
let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings);
let results = (outs LLVM_AnyStruct:$output);
let assemblyFormat = "attr-dict";
let assemblyFormat = "$input attr-dict `:` type($input)";
}

def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {
@@ -9,9 +9,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
@@ -6,11 +6,11 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -668,6 +668,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

// Required by CallOpInterface.
MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}

}];

let assemblyFormat = [{

@@ -269,17 +269,17 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}

def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> {
def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
AllTypesMatch<["input", "output"]>]> {
let summary = "dot wait";

let arguments = (ins TT_FpIntTensor:$input, I32Attr:$pendings);
let results = (outs TT_FpIntTensor:$output);
let description = [{
This operation defining the waiting action for a async dot, MMAv3 .e.g.
The subsequent operations should not execute until this operation completes waiting.
}];

let arguments = (ins I32Attr:$pendings);

let assemblyFormat = "attr-dict";
let assemblyFormat = "$input attr-dict `:` type($input)";
}

def TTNG_StoreAsyncOp : TTNG_Op<"store_async",
@@ -10,7 +10,8 @@ class Module;
namespace triton {

// Translate TritonGPU IR to PTX code.
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
bool enable_fp_fusion);

} // namespace triton
@@ -885,7 +885,8 @@ public:
//===----------------------------------------------------------------------===//

AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
: dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
solver) {
// UnrealizedConversionCast:
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
@@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
getParentOrder(getSrcLayout())[0];
}

SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = triton::gpu::getOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
// insert axis at the beginning of order
order.insert(order.begin(), axis);
return order;
}
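Illustration (not part of the Triton sources): a minimal standalone C++ sketch of the reordering that the new getOrderWithAxisAtBeginning() helper performs, with hypothetical names and a tiny driver.

#include <algorithm>
#include <cstdio>
#include <vector>

// Move the reduction axis to the front of the dimension order so that it
// becomes the fastest-varying index (mirrors the helper shown above).
std::vector<unsigned> orderWithAxisAtBeginning(std::vector<unsigned> order,
                                               unsigned axis) {
  auto it = std::find(order.begin(), order.end(), axis);
  if (it != order.end())
    order.erase(it);                 // delete the axis from the order
  order.insert(order.begin(), axis); // re-insert it at the beginning
  return order;
}

int main() {
  // order = {2, 0, 1}, axis = 1  ->  {1, 2, 0}
  for (unsigned d : orderWithAxisAtBeginning({2, 0, 1}, 1))
    std::printf("%u ", d);
  std::printf("\n");
  return 0;
}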
// Thread offset is the thread index offset of two adjacent threads on the
// reduction axis within the warp.
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
@@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
if (threadsPerWarp.size() == 1) {
threadOffset = 1;
} else {
assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
auto order = triton::gpu::getOrder(srcLayout);
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
threadOffset *= threadsPerWarp[order[i]];
}
}
return threadOffset;
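A standalone C++ sketch (hypothetical names, assuming the offset starts at 1 as in the one-dimensional case) of what the new loop computes: the offset between neighbouring threads on the reduction axis is the product of threadsPerWarp over all dimensions that come before the axis in the layout order.

#include <cstdio>
#include <vector>

// threadOffset = product of threadsPerWarp[d] for every dimension d that is
// faster than `axis` in the layout order (mirrors the loop in the hunk above).
unsigned threadOffsetOnAxis(const std::vector<unsigned> &threadsPerWarp,
                            const std::vector<unsigned> &order, unsigned axis) {
  unsigned threadOffset = 1;
  for (unsigned dim : order) {
    if (dim == axis)
      break;
    threadOffset *= threadsPerWarp[dim];
  }
  return threadOffset;
}

int main() {
  // threadsPerWarp = {4, 8}, order = {1, 0}, axis = 0 -> offset 8
  std::printf("%u\n", threadOffsetOnAxis({4, 8}, {1, 0}, 0));
  return 0;
}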
@@ -150,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
}

bool ReduceOpHelper::isWarpSynchronous() {
auto argsLayout = getSrcLayout();
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
auto srcLayout = getSrcLayout();
auto srcShape = getSrcShape();
return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
1;
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
@@ -502,10 +515,10 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src && dst && src.getVersionMajor() == 3 &&
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
dst.getWarpsPerCTA()[1] == 1;
}

bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}

@@ -521,7 +534,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getElementType().isF16();
}

bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
@@ -713,7 +726,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = backwardFilter;
getBackwardSlice(currentOp, &backwardSlice, opt);
slice.insert(backwardSlice.begin(), backwardSlice.end());

// Compute and insert the forwardSlice starting from currentOp.
@@ -14,7 +14,7 @@ add_mlir_conversion_library(NVGPUToLLVM
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms
@@ -29,8 +29,6 @@ const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;";
const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;";
const std::string Cga_Barrier_Sync_op = "barrier.cluster.sync.aligned;";
const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;";
const std::string Wgmma_Wait_Group_Op =
"wgmma.wait_group.sync.aligned #pendings;";
const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;";
const std::string Fence_Mbarrier_Init_Op =
"fence.mbarrier_init.release.cluster;";
@@ -200,29 +198,6 @@ public:
return {};
}

Type getReturnType(std::vector<std::string> outputConstraints,
mlir::PatternRewriter &rewriter) const {
auto ctx = rewriter.getContext();
Type resTy;
if (outputConstraints.empty()) {
resTy = void_ty(ctx);
} else {
SmallVector<Type> retTys;
for (auto &outputConstraint : outputConstraints) {
assert(outputConstraint[0] == '=' &&
"Constraint must be for an output");
Type retTy = getTypeFromConstraint(outputConstraint[1], rewriter);
retTys.push_back(retTy);
}
if (retTys.size() == 1) {
resTy = retTys[0];
} else {
resTy = struct_ty(retTys);
}
}
return resTy;
}

std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const {
std::vector<std::pair<int, int>> patchLocations;
std::vector<std::string> patchValues;
@@ -285,7 +260,8 @@ public:
outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end());
auto &ptxInstr = *ptxBuilder.create<PTXInstr>(ptxAsmPatched);
ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true);
auto retTy = getReturnType(outputConstraints, rewriter);
auto retTy =
op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType();
auto res = ptxBuilder.launch(rewriter, loc, retTy,
/*hasSideEffects*/ hasSideEffects);
if (op->getNumResults() == 0) {
@@ -700,6 +676,45 @@ public:
}
};
class WGMMAWaitGroupOpPattern
: public NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp,
WGMMAWaitGroupOpPattern> {
public:
using Base =
NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp, WGMMAWaitGroupOpPattern>;
using Base::Base;

std::vector<std::string>
getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = op.getType().cast<LLVM::LLVMStructType>();
uint32_t numOutputRegs = outputStructType.getBody().size();
std::string output =
outputStructType.getBody().front().isF32() ? "=f" : "=r";
return std::vector<std::string>(numOutputRegs, output);
}

OperandsAndConstraints
getOperandsAndConstraints(ttn::WGMMAWaitGroupOp op) const {
OperandsAndConstraints operandsAndConstraints;
auto input = op.getInput();
operandsAndConstraints.push_back({input, "0"});
return operandsAndConstraints;
}

std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = op.getType().dyn_cast<LLVM::LLVMStructType>();
uint32_t numCRegs = outputStructType.getBody().size();
std::string args = "";
uint32_t asmOpIdx = 0;
for (uint32_t i = 0; i < numCRegs; ++i) {
args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
}
auto ptxAsm = "// wait for regs: " + args + "\n\t" +
"wgmma.wait_group.sync.aligned #pendings;";
return ptxAsm;
}
};
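For concreteness, a small standalone C++ program (illustrative, not part of the sources) that rebuilds the inline-asm string getPtxAsm() above produces for a given number of accumulator registers; the registers $0..$N-1 are tied to the struct elements so the wait acts as a barrier on their definition.

#include <cstdio>
#include <string>

// Rebuilds the asm string produced by getPtxAsm() above: the accumulator
// registers $0..$N-1 are listed as tied operands, followed by the wait.
std::string waitGroupAsm(unsigned numCRegs) {
  std::string args;
  for (unsigned i = 0; i < numCRegs; ++i)
    args += "$" + std::to_string(i) + (i + 1 == numCRegs ? "" : ",");
  return "// wait for regs: " + args +
         "\n\twgmma.wait_group.sync.aligned #pendings;";
}

int main() {
  // Prints:
  //   // wait for regs: $0,$1,$2,$3
  //           wgmma.wait_group.sync.aligned #pendings;
  std::puts(waitGroupAsm(4).c_str());
  return 0;
}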
class WGMMAOpPattern : public NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern> {
public:
using Base = NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern>;
@@ -1072,7 +1087,6 @@ public:
POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op)
POPULATE_NVGPU_OP(ttn::CGABarrierSyncOp, Cga_Barrier_Sync_op)
POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op)
POPULATE_NVGPU_OP(ttn::WGMMAWaitGroupOp, Wgmma_Wait_Group_Op)
POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op)
POPULATE_NVGPU_OP(ttn::FenceMBarrierInitOp, Fence_Mbarrier_Init_Op)
POPULATE_NVGPU_OP(ttn::CGABarrierArriveOp, Cga_Barrier_Arrive_Op)
@@ -1100,7 +1114,8 @@ public:
OffsetOfStmatrixV4OpPattern, MBarrierArriveOpPattern,
ClusterArriveOpPattern, TMALoadTiledOpPattern,
TMAStoreTiledOpPattern, LoadDSmemOpPattern, WGMMAOpPattern,
StoreDSmemOpPattern, OffsetOfSts64OpPattern>(context);
WGMMAWaitGroupOpPattern, StoreDSmemOpPattern,
OffsetOfSts64OpPattern>(context);

if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed())
signalPassFailure();
@@ -49,7 +49,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
ASMBuilder
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms
@@ -146,10 +146,8 @@ struct DotWaitOpConversion
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto pendings = op.getPendings();
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(op.getLoc(), pendings);

// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
op, adaptor.getInput(), pendings);
return success();
}
};

@@ -168,7 +168,9 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
int numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0] * wpt[0]);
int numRepK = ceil<unsigned>(shapePerCTA[1], instrShape[2]);

Value warp = udiv(thread, i32_val(32));
// The descriptor should be calculated based on the first warp of the
// warpgroup.
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
Value warpM = urem(warp, i32_val(wpt[0]));
Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));

@@ -199,7 +201,7 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
int numRepK = ceil<unsigned>(shapePerCTA[0], instrShape[2]);
int numRepN = ceil<unsigned>(shapePerCTA[1], instrShape[1] * wpt[1]);

Value warp = udiv(thread, i32_val(32));
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
Value warpMN = udiv(warp, i32_val(wpt[0]));
Value warpN = urem(warpMN, i32_val(wpt[1]));
Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
@@ -293,6 +295,26 @@ static bool isZero(Value v) {
return false;
}

static SmallVector<Value> emitWait(ConversionPatternRewriter &rewriter,
Location loc, SmallVector<Value> acc,
int pendings) {
SmallVector<Type> types(acc.size(), acc[0].getType());
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
int i = 0;
for (Value v : acc) {
llvmStruct = insert_val(structTy, llvmStruct, v, i++);
}
Value res = rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, llvmStruct,
pendings);
SmallVector<Value> results;
for (int i = 0; i < acc.size(); ++i) {
results.push_back(extract_val(types[0], res, i));
}
return results;
}

LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
Operation *op, Value a, Value b, Value c, Value d,
@@ -427,7 +449,7 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);

if (sync)
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, 0);
mmaResults = emitWait(rewriter, loc, mmaResults, 0);

SmallVector<Value> results =
unpackAccumulator(rewriter, loc, mmaResults, dTensorTy);
@@ -6,7 +6,24 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
} else {
ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
}
return ret;
}

<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
@@ -356,11 +373,115 @@ const std::string Bf16_to_Fp8E5M2 =
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
#endif
=======
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
} else {
ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
}
return ret;
}

static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
// exponent compensate = 112
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
} else {
ret = "{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
".reg .b16 b<2>; \n"
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
"mov.b32 {a0, a1}, a; \n"
"cvt.bf16.f16 b0, a0; \n"
"cvt.bf16.f16 b1, a1; \n"
"mov.b32 $0, {b0, b1}; \n"
"}";
}
return ret;
}
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
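A scalar check (plain C++, outside the sources) of the exponent arithmetic in the non-native Fp8E5M2_to_Bf16 path above: both formats carry a sign bit and a 5-or-more-bit exponent, so the conversion is a right shift by 3 plus a rebias of 127 - 15 = 112; zero, subnormals and NaN are not treated specially here.

#include <cstdint>
#include <cstdio>

// Scalar model of the non-native path above: place the E5M2 byte in the high
// byte of a 16-bit lane (prmt), shift exponent+mantissa into BF16 position
// (shr 3), rebias the exponent by 127 - 15 = 112 (add 0x3800), restore sign.
uint16_t e5m2ToBf16Bits(uint8_t fp8) {
  uint16_t lane = uint16_t(fp8) << 8;     // e.g. 0x3c -> 0x3c00
  uint16_t mag = uint16_t((lane & 0x7fff) >> 3);
  mag = uint16_t(mag + (112u << 7));      // exponent compensate = 112
  return uint16_t(mag | (lane & 0x8000)); // restore sign
}

int main() {
  // E5M2 1.0 is 0x3c; the BF16 bit pattern for 1.0 is 0x3f80.
  std::printf("0x%04x\n", e5m2ToBf16Bits(0x3c));
  return 0;
}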
static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret =
"{ \n" // bf16=fp8>>3 + 112<<7
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
"mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800
"mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0
"mov.u32 rn_, 0x00100010; \n" // round to nearest
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)

// nosign = clamp(nosign, min, max)
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"

"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
"sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800
"sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset)
"shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3
"shl.b32 nosign1, nosign1, 3; \n" // shift into to fp8e4
"prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200
// nosign1 = 0xf300f400
// nosign = 0xf3f4f1f2
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
} else {
ret = "{ \n"
".reg .b16 a<2>; \n"
".reg .f32 b<2>; \n"
"mov.b32 {a0, a1}, $1; \n"
"cvt.f32.bf16 b0, a0; \n"
"cvt.f32.bf16 b1, a1; \n"
"cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n"
"}";
}
return ret;
}
/* ----- FP8E4M3B15 ------ */
// This data-type is a variant of the standard FP8E4M3 format.
// It was designed for fast software conversion to FP16 on
// nvidia GPUs that do not support it natively.
<<<<<<< HEAD
// Specifically, this data-type:
// - has infinities
// - has multiple nans (when all exponent bits are 1)
@@ -404,6 +525,11 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
=======
// This is the same format as FP8E4M3Nv, but:
// - the exponent bias is 15 instead of 7
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
const std::string Fp8E4M3B15_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
@@ -416,6 +542,7 @@ const std::string Fp8E4M3B15_to_Fp16 =
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"shl.b32 $1, b1, 7; \n"
"} \n";
<<<<<<< HEAD
#endif

#ifdef USE_ROCM
@@ -464,6 +591,10 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
=======

static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
std::string ret;
ret += "{ \n"
".reg .pred p<4>; \n"
@@ -509,6 +640,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
// $0 = (($2 << 1) & 0x80008000u) | (($2 << 7) & 0x3f803f80u);
// $1 = (($2 << 0) & 0x80008000u) | (($2 << 0) & 0x3f803f80u);
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -540,6 +672,9 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp8E4M3B15x4_to_Fp16 =
=======
static const std::string Fp8E4M3B15x4_to_Fp16 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"add.u32 a0, $2, $2; \n"
@@ -557,6 +692,7 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
// ((e4.y >> 0) & (0x80008000u >> 0)) |
// ((e4.y >> 0) & (0x3f803f80u >> 0)) ;
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
@@ -591,6 +727,9 @@ Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp16_to_Fp8E4M3B15x4 =
=======
static const std::string Fp16_to_Fp8E4M3B15x4 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"shr.b32 a0, $1, 1; \n"
@@ -904,17 +1043,18 @@ const std::string Bf16_to_Fp8E4M3 =
#endif

// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
"}";
static const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
"}";
// Fp16 (x2) -> Fp8E4M3 (x2) (packed)
const std::string Fp16_to_Fp8E4M3Nv = "{ \n"
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";
static const std::string Fp16_to_Fp8E4M3Nv =
"{ \n"
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";

#ifndef USE_ROCM
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
const std::string Fp8E4M3Nv_to_Bf16 =
static const std::string Fp8E4M3Nv_to_Bf16 =
"{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
@@ -927,7 +1067,7 @@ const std::string Fp8E4M3Nv_to_Bf16 =
"}";

// Bf16 (x2) -> Fp8E4M3 (x2) (packed)
const std::string Bf16_to_Fp8E4M3Nv =
static const std::string Bf16_to_Fp8E4M3Nv =
"{ \n"
".reg .b16 a<2>; \n"
".reg .f32 b<2>; \n"
@@ -938,7 +1078,7 @@ const std::string Bf16_to_Fp8E4M3Nv =
"}";

/* ----- Packed integer to BF16 ------ */
const std::string S8_to_Bf16 =
static const std::string S8_to_Bf16 =
"{ \n"
".reg .s8 s<4>; \n"
".reg .f32 f<4>; \n"
@@ -952,6 +1092,12 @@ const std::string S8_to_Bf16 =
"}";
#endif

// Fp32 (x2) -> Fp8 (x2) (packed)
static const std::string Fp32_to_Fp8E4M3Nv =
"cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n";
static const std::string Fp32_to_Fp8E5M2 =
"cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n";

static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
Type inType, Type ouType) {
auto inTensorTy = inType.dyn_cast<RankedTensorType>();
@@ -1383,9 +1529,14 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
<<<<<<< HEAD
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
=======
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
// F16 -> F8
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
@@ -1393,27 +1544,44 @@ struct FpToFpOpConversion
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
<<<<<<< HEAD
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#ifndef USE_ROCM
=======
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
#endif
// BF16 -> F8
<<<<<<< HEAD
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
#ifndef USE_ROCM
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
#endif
=======
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
// F32 -> F8
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
};
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ()) {
if (srcTy.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 && srcTy.isFloat8E5M2())) {
inVecWidthBits = 16;
outVecWidthBits = 32;
}
if (dstTy.isFloat8E4M3FNUZ()) {
if (dstTy.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 && dstTy.isFloat8E5M2())) {
inVecWidthBits = 32;
outVecWidthBits = 16;
}
@@ -1450,18 +1618,24 @@ struct FpToFpOpConversion

size_t numElements = 4;
if (srcElementType.isFloat8E4M3FNUZ() ||
dstElementType.isFloat8E4M3FNUZ()) {
dstElementType.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 &&
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
numElements = 2;
}
bool isSrcFP32 = srcElementType.isF32();
bool useFP16IntermediateSrc =
srcElementType.isF32() &&
!(computeCapability >= 90 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
bool isDstFP32 = dstElementType.isF32();
auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
auto cvtFunc =
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
SmallVector<Value> inVals;
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
}
if (isSrcFP32)
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16(loc, rewriter, v);
inVals.resize(numElements,
@@ -2115,18 +2289,18 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
#undef POPULATE_BINARY_OP

#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
@@ -1549,10 +1549,12 @@ struct InsertSliceAsyncOpConversion
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
"Unexpected srcLayout in InsertSliceAsyncOpConversion"));
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
"insert_slice_async: Unexpected rank of %src");

Value llDst = adaptor.getDst();
@@ -1617,25 +1619,15 @@ struct InsertSliceAsyncOpConversion
unsigned numElems = getTotalElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
DenseMap<unsigned, Value> sharedPtrs =
getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
smemObj, rewriter, offsetVals, srcStrides);

// If perPhase * maxPhase > threadsPerCTA, we will have elements
// that share the same tile indices. The index calculation will
// be cached.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);

auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);

for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// 16 * 8 = 128bits
auto maxBitWidth =
@@ -419,16 +419,15 @@ private:
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];

if (!helper.isReductionOnLayoutFastAxis()) {
std::reverse(order.begin(), order.end());
}
auto smemOrder = helper.getOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> acc = it.second;

SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = warpIdAxis;
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShape, smemOrder);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemPtrTy = getElementPtrType(op, i);
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
@@ -513,10 +512,7 @@ private:
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto order = getOrder(srcLayout);
if (!helper.isReductionOnLayoutFastAxis()) {
std::reverse(order.begin(), order.end());
}
auto smemOrder = helper.getOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
@@ -532,7 +528,7 @@ private:
SmallVector<Value> readIdx = resultIndices[j];
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShape, order);
linearize(rewriter, loc, readIdx, smemShape, smemOrder);
Value readPtr =
gep(getElementPtrType(op, i), smemBases[i], readOffset);
resultVals[j] = load(readPtr);
@@ -622,10 +622,13 @@ struct AllocTensorOpConversion
// TODO: we need to modify the pipeline pass to give a proper shared
// encoding to 3D tensors
SmallVector<unsigned> newOrder;
if (resultTy.getShape().size() == 3)
newOrder = {1 + order[0], 1 + order[1], 0};
else
if (resultTy.getShape().size() != order.size()) {
for (auto i = 0; i < order.size(); ++i)
newOrder.push_back(order[i] + 1);
newOrder.push_back(0);
} else {
newOrder = SmallVector<unsigned>(order.begin(), order.end());
}

auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
auto smemObj =
@@ -659,10 +662,13 @@ struct ExtractSliceOpConversion
SmallVector<Value, 4> opOffsetVals;
SmallVector<Value, 4> offsetVals;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i))
opOffsetVals.emplace_back(adaptor.getOffsets()[i]);
else
for (auto i = 0, j = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i)) {
// adaptor.getOffsets() returns list of variable offsets. the size of
// the list may not be the same as mixedOffsets
opOffsetVals.emplace_back(adaptor.getOffsets()[j]);
++j;
} else
opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
}
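Standalone illustration (plain C++, hypothetical names) of the indexing fix above: mixedOffsets interleaves static and dynamic offsets, while the adaptor only carries the dynamic ones, so a second counter j is needed to walk the dynamic list.

#include <cstdio>
#include <optional>
#include <vector>

// Resolve a list of mixed offsets: entries without a value are dynamic and
// are consumed from `dynamicVals` in order, everything else is static.
std::vector<int> resolveOffsets(const std::vector<std::optional<int>> &mixed,
                                const std::vector<int> &dynamicVals) {
  std::vector<int> out;
  std::size_t j = 0; // index into dynamicVals only
  for (std::size_t i = 0; i < mixed.size(); ++i) {
    if (!mixed[i].has_value())
      out.push_back(dynamicVals[j++]); // dynamic offset
    else
      out.push_back(*mixed[i]);        // static offset
  }
  return out;
}

int main() {
  // offsets [?, 4, ?] with dynamic values {10, 20} -> {10, 4, 20}
  for (int v : resolveOffsets({std::nullopt, 4, std::nullopt}, {10, 20}))
    std::printf("%d ", v);
  std::printf("\n");
  return 0;
}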
@@ -146,7 +146,8 @@ protected:
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, LLVM::CConv::C, attributes);
/*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/SymbolRefAttr{},
attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
@@ -361,8 +362,13 @@ public:
unsigned numElemsPerSwizzlingRow =
swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth();
Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
unsigned leadingDimOffset;
if (outOrder.size() == 2) {
leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
} else {
leadingDimOffset = numElemsPerSwizzlingRow;
}

Value leadingDimOffsetVal = i32_val(leadingDimOffset);
// Return values
DenseMap<unsigned, Value> ret;
@@ -374,9 +380,15 @@ public:
// Extract multi dimensional index for current element
auto idx = srcIndices[elemIdx];
Value idxCol = idx[outOrder[0]]; // contiguous dimension
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
Value idxRow, strideRow;
if (outOrder.size() == 2) {
idxRow = idx[outOrder[1]]; // discontiguous dimension
strideRow = srcStrides[outOrder[1]];
} else {
idxRow = i32_val(0);
strideRow = i32_val(0);
}
Value strideCol = srcStrides[outOrder[0]];
Value strideRow = srcStrides[outOrder[1]];
// compute phase = (row // perPhase) % maxPhase
Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
// extract dynamic/static offset for immediate offsetting
@@ -428,10 +440,16 @@ public:
offset = add(offset, add(rowOff, mul(colOff, strideCol)));
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
// compute immediate offset
Value immedateOff =
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff);
Value immediateOff;
if (outOrder.size() == 2) {
immediateOff =
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
} else {
immediateOff = i32_val(immedateOffCol);
}

ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff);
}
return ret;
}
@@ -371,13 +371,15 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
Type type = val.getType();
if (type != i32_ty) {
val = bitcast(val, int_ty(bits));
val = zext(i32_ty, val);
if (bits < 32)
val = zext(i32_ty, val);
}
Value mask = i32_val(0xFFFFFFFF);
Value result = rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, val, i, clamp,
mode, UnitAttr());
if (type != i32_ty) {
result = trunc(int_ty(bits), result);
if (bits < 32)
result = trunc(int_ty(bits), result);
result = bitcast(result, type);
}
return result;

@@ -97,7 +97,7 @@
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
} while (0)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)

// Types
@@ -115,8 +115,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
// Floating point
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
// MaxMin
GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
// Floating point
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,

@@ -1,9 +1,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"


@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
TritonGPUAttrDefsIncGen

LINK_LIBS PUBLIC
MLIRGPUOps
MLIRGPUDialect
TritonIR
)
@@ -1603,7 +1603,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();

// Deduce operand_segment_sizes from the number of the operands.
// Deduce operandSegmentSizes from the number of the operands.
auto operandSegmentSizesAttrName =
OpT::getOperandSegmentSizesAttrName(result.name);
result.addAttribute(
@@ -1616,7 +1616,7 @@ template <class OpT>
void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
printer << " ";
printer << insertSliceOp.getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
// "operandSegmentSizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(
insertSliceOp->getAttrs(),
{insertSliceOp.getOperandSegmentSizesAttrName()});
@@ -139,7 +139,10 @@ class BlockedToMMA : public mlir::RewritePattern {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
mlir::getBackwardSlice(x, &slice, bwdFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = bwdFilter;
getBackwardSlice(x, &slice, opt);
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
@@ -235,8 +238,11 @@ public:
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
getBackwardSlice(a, &aBwdSlices, {isCvt});
getBackwardSlice(b, &bBwdSlices, {isCvt});
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = isCvt;
getBackwardSlice(a, &aBwdSlices, opt);
getBackwardSlice(b, &bBwdSlices, opt);
// get the source of the first conversion found in slices
auto getCvtArgOrder = [](Operation *op) {
return cast<ConvertLayoutOp>(op)

@@ -98,7 +98,9 @@ public:
// and all operations between the load and the conversion
// should be layout preserving
SetVector<Operation *> slice;
getBackwardSlice(op, &slice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op, &slice, opt);
int loadIdx = -1;
bool checkOp = false;
for (int i = 0; i < slice.size(); i++) {
@@ -160,6 +160,8 @@ class LoopPipeliner {
void checkOpShareBarriers(SetVector<Operation *> &ops);
int numLoadsRequireAsyncWait = 0;
int numLoadsRequireMBarrier = 0;
// Number of buffers to allocate for each input.
int numSharedMemorySlices = 0;

/// Iterator values
Value nextIV;
@@ -280,9 +282,12 @@ class LoopPipeliner {

public:
LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs,
bool mode, ConsumerReleaseMap &consumerReleaseMap)
bool mode, int numSharedMemorySlices,
ConsumerReleaseMap &consumerReleaseMap)
: forOp(forOp), numStages(numStages), numWarps(numWarps),
numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) {
numCTAs(numCTAs), mode(mode),
numSharedMemorySlices(numSharedMemorySlices),
consumerReleaseMap(consumerReleaseMap) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
@@ -644,7 +649,7 @@ void LoopPipeliner::createBufferTypes() {
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
bufferShape.insert(bufferShape.begin(), numSharedMemorySlices);
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
Attribute sharedEnc;
if (auto dotOpEnc = cvt.getType()
@@ -946,6 +951,11 @@ void LoopPipeliner::emitPrologue() {
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
Value numSlices = builder.create<arith::ConstantIntOp>(
iv.getLoc(), numSharedMemorySlices, 32);
Value _0 = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx,
numSlices, pipelineIterIdx, _0);
// Some values have not been used by any ops in the loop body
for (BlockArgument arg : forOp.getRegionIterArgs())
setValueMappingYield(arg, valueMapping[arg][stage], stage + 1);
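Aside (an assumption-labelled sketch, not part of the diff): the calls to getBoundedIterationValue above appear to wrap the running index back to zero once it reaches the number of shared-memory slices, so the pipeline cycles through a circular buffer. A scalar C++ analogue of that indexing, assuming exactly that wrap-around behaviour:

#include <cstdio>

// Assumed behaviour: once the incremented index reaches numSlices it wraps to 0.
unsigned nextSliceIndex(unsigned idx, unsigned numSlices) {
  unsigned next = idx + 1;
  return next == numSlices ? 0u : next;
}

int main() {
  unsigned idx = 0;
  for (int iter = 0; iter < 8; ++iter) {
    std::printf("iteration %d uses slice %u\n", iter, idx);
    idx = nextSliceIndex(idx, /*numSharedMemorySlices=*/3);
  }
  return 0;
}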
@@ -1220,11 +1230,13 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
|
||||
Value _1 = builder.create<arith::ConstantIntOp>(idxLoc, 1, 32);
|
||||
Value numStagesVal =
|
||||
builder.create<arith::ConstantIntOp>(idxLoc, numStages, 32);
|
||||
Value numSlices =
|
||||
builder.create<arith::ConstantIntOp>(idxLoc, numSharedMemorySlices, 32);
|
||||
|
||||
// nextWaitIdx
|
||||
Value waitIdxPlusOne = builder.create<arith::AddIOp>(idxLoc, curWaitIdx, _1);
|
||||
Value nextWaitIdx = getBoundedIterationValue(
|
||||
builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0);
|
||||
Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne,
|
||||
numSlices, waitIdxPlusOne, _0);
|
||||
|
||||
// Indices of InsertSliceAsyncOp and ExtractSliceOp
|
||||
Value insertSliceIndex = pipelineIterIdx;
|
||||
@@ -1417,9 +1429,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
|
||||
// Bump pipelineIterIdx
|
||||
Value pipelineIterIdxPlusOne =
|
||||
builder.create<arith::AddIOp>(idxLoc, pipelineIterIdx, _1);
|
||||
pipelineIterIdx =
|
||||
getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal,
|
||||
pipelineIterIdxPlusOne, _0);
|
||||
pipelineIterIdx = getBoundedIterationValue(
|
||||
builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0);
|
||||
|
||||
// Bump curWaitIdx
|
||||
curWaitIdx = nextWaitIdx;
|
||||
@@ -1516,10 +1527,23 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
// applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
|
||||
llvm::SmallVector<scf::ForOp> newForOps;
|
||||
|
||||
// Currently we schedule stage 0 after stage `numStages - 1` during
// pipelining, therefore we only need `numStages - 1` slices of memory.
|
||||
// On Hopper we have a separate post-processing that pipelines wgmma so we
|
||||
// need an extra buffer for each input.
|
||||
// Note that an alternative would be to keep allocating `numStages` buffers
|
||||
// and remove the barrier between the loads from shared memory and the
|
||||
// copies from global to shared. This would require improving existing
|
||||
// membar analysis.
|
||||
int numSharedMemorySlices =
|
||||
computeCapability < 90 ? numStages - 1 : numStages;
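To make the slice accounting above easier to follow, here is a minimal Python sketch of the same logic; the helper names are illustrative and not part of the actual pass.

def num_shared_memory_slices(num_stages, compute_capability):
    # Pre-Hopper: stage 0 is scheduled after stage num_stages - 1, so one
    # buffer can be reused; Hopper keeps an extra buffer per input for the
    # separate wgmma post-pipelining.
    return num_stages - 1 if compute_capability < 90 else num_stages

def next_slice_index(idx, num_slices):
    # Wrap-around slot selection, matching what getBoundedIterationValue
    # emits as arith ops: advance, then reset to 0 once the bound is reached.
    idx += 1
    return 0 if idx >= num_slices else idx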
// Do the pipelining
|
||||
getOperation()->walk([&](scf::ForOp forOp) -> void {
|
||||
LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
|
||||
this->numCTAs, mode, consumerReleaseMap);
|
||||
this->numCTAs, mode, numSharedMemorySlices,
|
||||
consumerReleaseMap);
|
||||
if (pipeliner.initialize().failed())
|
||||
return;
|
||||
|
||||
@@ -1593,7 +1617,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
|
||||
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
|
||||
/// dots to be pipelined
|
||||
SetVector<Value> dots;
|
||||
SmallVector<tt::DotOp> dots;
|
||||
SmallVector<unsigned> resultNeedSync;
|
||||
for (Operation &op : *loop) {
|
||||
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
|
||||
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
@@ -1615,8 +1640,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
if (!CArg || !CArg.hasOneUse())
|
||||
valid = false;
|
||||
|
||||
if (valid)
|
||||
dots.insert(dotOp);
|
||||
if (valid) {
|
||||
dots.push_back(dotOp);
|
||||
resultNeedSync.push_back(
|
||||
dotOp->getUses().begin()->getOperandNumber());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1627,39 +1655,39 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
return;
|
||||
|
||||
OpBuilder builder(forOp);
|
||||
|
||||
// 0. insert dot_wait after the last dot in the loop
|
||||
Value dot = dots.back();
|
||||
auto loc = dot.getLoc();
|
||||
builder.setInsertionPointAfter(dot.getDefiningOp());
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(loc, dots.size());
|
||||
// 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
|
||||
// wgmma ops by one stage.
|
||||
// This is needed to prevent shared memory inputs from being overwritten
// before the operation is completed.
|
||||
// TODO: merge this with the rest of the pipelining transformation and look at
|
||||
// a better representation for async dots.
|
||||
tt::DotOp lastDot = dots.back();
|
||||
builder.setInsertionPointAfter(lastDot);
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
|
||||
lastDot.getLoc(), lastDot.getResult(), dots.size());
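A minimal Python model of why the wait goes after the last dot (this only illustrates the one-stage wgmma pipelining described in the comment above, not the pass itself):

def loop_with_async_dots(num_iters, num_dots):
    # Each iteration issues all of its dots asynchronously, then waits until
    # at most `num_dots` of them are still pending, i.e. until the previous
    # iteration's dots have finished and their shared-memory inputs are free.
    pending = 0
    for _ in range(num_iters):
        pending += num_dots               # dot_async for every dot in the body
        pending = min(pending, num_dots)  # dot_wait(pendings = num_dots)
    return pending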
// 1. replace Dot with DotAsync
|
||||
for (size_t idx = 0; idx < dots.size(); ++idx) {
|
||||
Value dot = dots[idx];
|
||||
auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
|
||||
builder.setInsertionPoint(dot.getDefiningOp());
|
||||
tt::DotOp dotOp = dots[idx];
|
||||
builder.setInsertionPoint(dotOp);
|
||||
auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
|
||||
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
|
||||
dotOp.getMaxNumImpreciseAcc());
|
||||
dot.replaceAllUsesWith(dotAsync.getResult());
|
||||
updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
|
||||
dot.getDefiningOp()->erase();
|
||||
dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dotOp.replaceAllUsesWith(dotAsync.getResult());
|
||||
updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
|
||||
dotOp->erase();
|
||||
}
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
Value loopNotEmpty = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
|
||||
forOp.getUpperBound());
|
||||
// TODO[goostavz]: this is a workaround that puts the DotWaitOp in an IfOp, for
// a bug in ptxas which mistakenly analyzes the control flow and turns the GMMA
// into a synchronous implementation for safety.
// Remove this If once the bug is fixed.
|
||||
auto ifOp = builder.create<scf::IfOp>(loc, ArrayRef<Type>{}, loopNotEmpty,
|
||||
/*hasElse*/ false);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), 0);
|
||||
for (unsigned resultIndex : resultNeedSync) {
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
auto dotWait =
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
}
|
||||
}
|
||||
|
||||
Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,
|
||||
|
||||
@@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr;
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
<<<<<<< HEAD
|
||||
// convert(blocked, dot_operand) ->
|
||||
// convert(blocked, mma) + convert(mma, dot_operand)
|
||||
// if this value is itself the result of a dot operation
|
||||
@@ -102,6 +103,9 @@ public:
|
||||
};
|
||||
|
||||
//
|
||||
=======
|
||||
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
class ConvertDotConvert : public mlir::RewritePattern {
|
||||
public:
|
||||
ConvertDotConvert(mlir::MLIRContext *context)
|
||||
@@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
|
||||
getForwardSlice(currentValue, &forwardSlice);
|
||||
for (Operation *op : forwardSlice) {
|
||||
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
if (convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.isa<triton::gpu::MmaEncodingAttr>())
|
||||
return true;
|
||||
Attribute dstEncoding = convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding();
|
||||
if (auto mmaLayout =
|
||||
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
|
||||
return (mmaLayout.getVersionMajor() > 1) ? true
|
||||
: mmaLayout == encoding;
|
||||
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return encoding.cast<triton::gpu::MmaEncodingAttr>()
|
||||
.getVersionMajor() > 1;
|
||||
}
|
||||
auto yield = dyn_cast<scf::YieldOp>(op);
|
||||
if (!yield)
|
||||
@@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
|
||||
return rewrittenValue;
|
||||
OpBuilder rewriter(value.getContext());
|
||||
rewriter.setInsertionPointAfterValue(rewrittenValue);
|
||||
// Workaround: The pipeliner will insert async.wait after a pipelined loop
// to ensure that there are no pending copies and it is safe to reuse shared
// memory. We shouldn't insert ops that may use shared memory in between the
|
||||
// loop and the async.wait. This is a hack until we fix the IR
|
||||
// representation of async wait.
|
||||
if (Operation *op = rewrittenValue.getDefiningOp()) {
|
||||
if (isa<triton::gpu::AsyncWaitOp>(op->getNextNode()))
|
||||
rewriter.setInsertionPointAfter(op->getNextNode());
|
||||
}
|
||||
auto tmpType = RankedTensorType::get(tensorType.getShape(),
|
||||
tensorType.getElementType(), encoding);
|
||||
Value converted = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -1122,7 +1140,6 @@ public:
|
||||
hoistConvert(m);
|
||||
|
||||
mlir::RewritePatternSet decomposePatterns(context);
|
||||
decomposePatterns.add<DecomposeDotOperand>(context);
|
||||
decomposePatterns.add<ConvertDotConvert>(context);
|
||||
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
|
||||
.failed()) {
|
||||
|
||||
@@ -91,7 +91,7 @@ private:
|
||||
// support ForOp only
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
|
||||
// prologue
|
||||
auto iterOperands = forOp.getIterOperands();
|
||||
auto iterOperands = forOp.getInitArgs();
|
||||
if (argNum == 0)
|
||||
return false;
|
||||
if (dependOnSharedEncOperand(iterOperands[argNum - 1]))
|
||||
|
||||
@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
|
||||
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
|
||||
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
|
||||
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
|
||||
arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
|
||||
arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
|
||||
arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
|
||||
arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
|
||||
arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
|
||||
arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
|
||||
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
|
||||
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
|
||||
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
|
||||
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
|
||||
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
|
||||
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
|
||||
arith::UIToFPOp, arith::XOrIOp>(op))
|
||||
return true;
|
||||
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
|
||||
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
|
||||
|
||||
@@ -220,7 +220,9 @@ public:
|
||||
SetVector<Operation *> backwardSlice;
|
||||
mod.walk([&](triton::MakeTensorPtrOp op) -> void {
|
||||
assert(isa<triton::FuncOp>(op->getParentOp()));
|
||||
getBackwardSlice(op.getOperation(), &backwardSlice);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
getBackwardSlice(op.getOperation(), &backwardSlice, opt);
|
||||
op->removeAttr("async_agent");
|
||||
});
|
||||
for (auto op : backwardSlice) {
|
||||
|
||||
@@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) {
|
||||
builder.setInsertionPoint(agentIdOp);
|
||||
Value globalRoleId = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
int globalNumWarps = 0;
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (auto cmpOp : agentIdOp->getUsers()) {
|
||||
assert(isa<arith::CmpIOp>(cmpOp));
|
||||
for (auto u : cmpOp->getUsers()) {
|
||||
@@ -111,11 +112,14 @@ void materializeGetAgentIdOp(Operation *parentOp) {
|
||||
Value cond =
|
||||
builder.create<arith::AndIOp>(loc, lowerBound, upperBound);
|
||||
cmpOp->getResult(0).replaceAllUsesWith(cond);
|
||||
cmpOp->erase();
|
||||
deprecatedOps.push_back(cmpOp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Operation *cmpOp : deprecatedOps) {
|
||||
cmpOp->erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -145,39 +149,24 @@ LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) {
|
||||
}
|
||||
|
||||
Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op,
|
||||
bool skipFirstWait) {
|
||||
bool emptyBarrier) {
|
||||
// TODO: currently we only support a single loop; nested loops, while loops,
// and conditionals are not handled.
|
||||
auto loc = op->getLoc();
|
||||
auto forOp = op->getParentOfType<scf::ForOp>();
|
||||
if (!forOp) {
|
||||
return builder.create<arith::ConstantIntOp>(loc, skipFirstWait, 1);
|
||||
return builder.create<arith::ConstantIntOp>(loc, emptyBarrier, 1);
|
||||
}
|
||||
|
||||
auto defOp = op->getOperand(0).getDefiningOp();
|
||||
assert(isa<ttng::CreateTokenOp>(defOp) &&
|
||||
"mbarrier's definingOp is not createTokenOp");
|
||||
ttng::CreateTokenOp createTokenOp = dyn_cast<ttng::CreateTokenOp>(defOp);
|
||||
Value numStage =
|
||||
builder.create<arith::ConstantIntOp>(loc, createTokenOp.getNum(), 32);
|
||||
Value curStep = forOp.getBody()->getArguments().back();
|
||||
if (curStep.getType() == builder.getIndexType()) {
|
||||
curStep =
|
||||
builder.create<arith::IndexCastOp>(loc, numStage.getType(), curStep);
|
||||
// for (..., phase, pipelineIdx)
|
||||
unsigned numArgs = forOp.getBody()->getNumArguments();
|
||||
assert(numArgs > 2 && "Unexpected number of arguments");
|
||||
Value curPhase = forOp.getBody()->getArgument(numArgs - 2);
|
||||
if (emptyBarrier) {
|
||||
Value _1_1b = builder.create<arith::ConstantIntOp>(loc, 1, 1);
|
||||
curPhase = builder.create<mlir::arith::XOrIOp>(loc, curPhase, _1_1b);
|
||||
}
|
||||
Value curPhase = builder.create<arith::DivUIOp>(loc, curStep, numStage);
|
||||
if (skipFirstWait) {
|
||||
// If skipFirstWait, it waits for phaseBit 1
|
||||
Value _1 = builder.create<arith::ConstantIntOp>(loc, 1, 32);
|
||||
curPhase = builder.create<arith::AddIOp>(loc, curPhase, _1);
|
||||
}
|
||||
Value _2 = builder.create<arith::ConstantIntOp>(loc, 2, 32);
|
||||
// TODO: May use alternative methods of phaseBit calculation to avoid high
|
||||
// overhead of RemOp
|
||||
Value phaseBit = builder.create<arith::RemUIOp>(loc, curPhase, _2);
|
||||
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, phaseBit,
|
||||
_0);
|
||||
return curPhase;
|
||||
}
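As a rough Python sketch of the phase selection that getMBarrierPhaseBit now performs (assuming the rewritten loop carries the phase as an i1 block argument; names are illustrative):

def mbarrier_phase(in_loop, carried_phase, empty_barrier):
    # Outside any loop the phase is a constant: the producer starts by
    # waiting on the "empty" phase (1), the consumer on the "full" phase (0).
    if not in_loop:
        return 1 if empty_barrier else 0
    # Inside the loop the phase is taken directly from the loop-carried
    # argument; the producer waits on the flipped bit.
    return carried_phase ^ 1 if empty_barrier else carried_phase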
int getTxBytes(ttng::InsertSliceAsyncV2Op load) {
|
||||
@@ -260,7 +249,7 @@ void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op,
|
||||
auto loc = op.getLoc();
|
||||
// The first producer_acquire should be satisfied immediately, so the producer
// initially skips the first wait
|
||||
Value phase = getMBarrierPhaseBit(builder, op, 1);
|
||||
Value phase = getMBarrierPhaseBit(builder, op, true);
|
||||
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferEmpty, phase);
|
||||
assert(op.getOperation()->hasAttr("async_agent"));
|
||||
setAgentIds(waitOp, getAgentIds(op.getOperation()));
|
||||
@@ -296,7 +285,7 @@ void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op,
|
||||
void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
|
||||
Value bufferFull) {
|
||||
auto loc = op.getLoc();
|
||||
Value phase = getMBarrierPhaseBit(builder, op, 0);
|
||||
Value phase = getMBarrierPhaseBit(builder, op, false);
|
||||
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferFull, phase);
|
||||
assert(op.getOperation()->hasAttr("async_agent"));
|
||||
setAgentIds(waitOp, getAgentIds(op.getOperation()));
|
||||
@@ -530,6 +519,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
|
||||
builder.create<arith::ConstantIntOp>(loc, nameBarrierId - 1, 32);
|
||||
// Process mutex users
|
||||
int numUsers = 0;
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (Operation *user : createMutexOp.getResult().getUsers()) {
|
||||
numUsers++;
|
||||
assert(numUsers <= 2);
|
||||
@@ -543,14 +533,20 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
|
||||
Value barLeave = builder.create<arith::SelectOp>(
|
||||
loc, isRole0, namedBarrierId1, namedBarrierId0);
|
||||
builder.create<ttng::NamedBarrierArriveOp>(loc, barLeave, numThreads);
|
||||
} else
|
||||
} else {
|
||||
llvm_unreachable("Unexpected user of mutex");
|
||||
}
|
||||
deprecatedOps.push_back(user);
|
||||
}
|
||||
for (Operation *user : deprecatedOps) {
|
||||
user->erase();
|
||||
}
|
||||
nameBarrierId -= 2;
|
||||
nameBarrierIdEnd -= 2;
|
||||
createMutexOp.erase();
|
||||
});
|
||||
|
||||
parentOp->walk(
|
||||
[](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); });
|
||||
}
|
||||
|
||||
void processLockOp(OpBuilder &builder, ttng::LockOp op) {
|
||||
@@ -587,6 +583,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
|
||||
OpBuilder builder(createMutexOp);
|
||||
|
||||
// Process mutex users
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (Operation *user : createMutexOp.getResult().getUsers()) {
|
||||
auto loc = user->getLoc();
|
||||
builder.setInsertionPoint(user);
|
||||
@@ -596,6 +593,9 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
|
||||
processUnlockOp(builder, op);
|
||||
else
|
||||
llvm_unreachable("Unexpected user of mutex");
|
||||
deprecatedOps.push_back(user);
|
||||
}
|
||||
for (Operation *user : deprecatedOps) {
|
||||
user->erase();
|
||||
}
|
||||
|
||||
|
||||
@@ -156,14 +156,20 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
persistentForOp.getInitArgsMutable()
|
||||
.slice(persistentForOp.getInitArgs().size() - 1, 1)
|
||||
.assign(newIdx);
|
||||
auto yield =
|
||||
llvm::cast<scf::YieldOp>(persistentForOp.getBody()->getTerminator());
|
||||
auto idxPlusOneOp =
|
||||
yield->getOperand(yield->getNumOperands() - 1).getDefiningOp();
|
||||
assert(isa<arith::AddIOp>(idxPlusOneOp));
|
||||
assert(idxPlusOneOp->getOperand(0) ==
|
||||
persistentForOp.getBody()->getArgument(
|
||||
persistentForOp.getBody()->getNumArguments() - 1));
|
||||
|
||||
pipelineIdx = persistentForOp.getBody()->getArgument(
|
||||
persistentForOp.getBody()->getNumArguments() - 1);
|
||||
Operation *idxPlusOneOp = nullptr;
|
||||
for (OpOperand &v : pipelineIdx.getUses()) {
|
||||
if (isa<arith::AddIOp>(v.getOwner())) {
|
||||
idxPlusOneOp = v.getOwner();
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(idxPlusOneOp && "idxPlusOneOp should be arith::AddIOp");
|
||||
Operation *use = *idxPlusOneOp->getUsers().begin();
|
||||
assert(isa<scf::YieldOp>(use) || isa<arith::SelectOp>(use) ||
|
||||
isa<arith::CmpIOp>(use));
|
||||
idxPlusOneOp->setOperand(1, numRolesValue);
|
||||
|
||||
// Add operations at the start of persistentForOp
|
||||
@@ -213,45 +219,6 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
unlockLocs[i] = op;
|
||||
}
|
||||
|
||||
// Update unlockLocs
|
||||
// ====================== IR after async launch dots ======================
|
||||
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
|
||||
// %3) {
|
||||
// * triton_nvidia_gpu.producer_wait arg2
|
||||
// * %5 = triton_nvidia_gpu.dot_async %4, %5
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 1}
|
||||
// * %6 = arith.cmpi sgt, arg0, %c0
|
||||
// * scf.if %6 {
|
||||
// * %7 = arith.subi arg2, c1
|
||||
// * triton_nvidia_gpu.consumer_release %7
|
||||
// * }
|
||||
// * %8 = arith.addi arg2, c1
|
||||
// * scf.yield %5, %8
|
||||
// * }
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 0}
|
||||
// * %9 = arith.subi %0#1, c1
|
||||
// * triton_nvidia_gpu.consumer_release %9
|
||||
// * =======================================================================
|
||||
// After async launch dots, there will be an outstanding consumerReleaseOp after
// the ForOp. We should extend the unlockLocs from the ForOp to the outstanding
// consumerReleaseOp.
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
Operation *unlockOp = unlockLocs[i];
|
||||
auto filter = [&](Operation *op) {
|
||||
return op->getBlock() == unlockOp->getBlock();
|
||||
};
|
||||
if (isa<scf::ForOp>(unlockOp)) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter});
|
||||
auto iter = llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
|
||||
});
|
||||
if (iter != slices.end()) {
|
||||
unlockLocs[i] = *iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only cases where all lock/unlock locations are in same level make sense.
|
||||
for (int i = 1; i < numRoles; ++i) {
|
||||
if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() ||
|
||||
@@ -281,6 +248,54 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
else
|
||||
lockLocs[i] = unlockLocs[prevTypeIds[i]];
|
||||
}
|
||||
|
||||
// Update lockLocs
|
||||
// ====================== IR after async launch dots ======================
|
||||
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
|
||||
// %3) {
|
||||
// * triton_nvidia_gpu.producer_wait arg2
|
||||
// * %5 = triton_nvidia_gpu.dot_async %4, %5
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 1}
|
||||
// * %6 = arith.cmpi sgt, arg0, %c0
|
||||
// * scf.if %6 {
|
||||
// * %7 = arith.subi arg2, c1
|
||||
// * triton_nvidia_gpu.consumer_release %7
|
||||
// * }
|
||||
// * %8 = arith.addi arg2, c1
|
||||
// * scf.yield %5, %8
|
||||
// * }
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 0}
|
||||
// * ...
|
||||
// * triton_nvidia_gpu.consumer_release ..
|
||||
// * =======================================================================
|
||||
// After async launch dots, there will be an outstanding consumerReleaseOp after
// the ForOp. We should set the epilogue lockLocs after the outstanding
// consumerReleaseOp.
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
Operation *lockOp = lockLocs[i];
|
||||
if (isa<scf::ForOp>(lockOp)) {
|
||||
Operation *loc = nullptr;
|
||||
unsigned numOutstandingConsumerRelease = 0;
|
||||
for (auto v : lockOp->getResults()) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(v, &slices);
|
||||
auto iter = llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
|
||||
});
|
||||
if (iter != slices.end()) {
|
||||
numOutstandingConsumerRelease++;
|
||||
loc = *iter;
|
||||
}
|
||||
}
|
||||
assert(numOutstandingConsumerRelease <= 1 &&
|
||||
"should have only one outstanding "
|
||||
"consumerReleaseOp after "
|
||||
"async launch dots");
|
||||
if (loc)
|
||||
lockLocs[i] = loc;
|
||||
}
|
||||
}
|
||||
|
||||
// lock
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
builder.setInsertionPointAfter(lockLocs[i]);
|
||||
|
||||
@@ -129,11 +129,12 @@ DenseMap<Channel *, Value> createBuffer(const SmallVector<Channel *> &channels,
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// appendPipelineIdxToLoopArgs
|
||||
// createNewLoops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
// for(...) -> for(..., pipelineIdx)
|
||||
scf::ForOp createNewPersistentLoop(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
auto loc = forOp.getLoc();
|
||||
Block *body = forOp.getBody();
|
||||
|
||||
@@ -200,6 +201,117 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
// for(...) -> for(..., phase, pipelineIdx)
|
||||
scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
auto loc = forOp.getLoc();
|
||||
Block *body = forOp.getBody();
|
||||
|
||||
// The agentId set of pipelineIdx is the union of agentId sets of all ops in
|
||||
// the for loop
|
||||
OpBuilderWithAgentIds builder(forOp.getContext());
|
||||
builder.setAgentIdsFromArray(collectAgentIds(forOp));
|
||||
|
||||
builder.setInsertionPoint(forOp);
|
||||
Value numStagesVal =
|
||||
builder.createWithAgentIds<arith::ConstantIntOp>(loc, numStages, 32);
|
||||
|
||||
// 0. Append phase and pipelineIdx to block arguments
|
||||
Value phase =
|
||||
body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc);
|
||||
Value pipelineIdx =
|
||||
body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc);
|
||||
|
||||
// 1. prepare index and phase for next iteration
|
||||
// nextIdx = curIdx + 1
|
||||
// nextPhase = ((nextIdx < numStages && curPhase) || (nextIdx >= numStages &&
|
||||
// curPhase^1))
|
||||
// nextIdx = nextIdx >= numStages ? 0 : nextIdx
|
||||
auto yieldOp = llvm::cast<scf::YieldOp>(body->getTerminator());
|
||||
builder.setInsertionPoint(yieldOp);
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value _1_1b = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 1);
|
||||
// generate index for next iter
|
||||
Value nextPipelineIdx =
|
||||
builder.createWithAgentIds<arith::AddIOp>(loc, pipelineIdx, one);
|
||||
Value pipelineGECond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::uge, nextPipelineIdx, numStagesVal);
|
||||
Value pipelineLTCond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ult, nextPipelineIdx, numStagesVal);
|
||||
Value cyclePipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, nextPipelineIdx, numStagesVal);
|
||||
nextPipelineIdx = builder.createWithAgentIds<mlir::arith::SelectOp>(
|
||||
loc, pipelineGECond, cyclePipelineIdx, nextPipelineIdx);
|
||||
// generate phase for next iter
|
||||
Value flipPhase =
|
||||
builder.createWithAgentIds<mlir::arith::XOrIOp>(loc, phase, _1_1b);
|
||||
Value cond0 = builder.createWithAgentIds<mlir::arith::AndIOp>(
|
||||
loc, pipelineGECond, flipPhase);
|
||||
Value cond1 = builder.createWithAgentIds<mlir::arith::AndIOp>(
|
||||
loc, pipelineLTCond, phase);
|
||||
Value nextPhase =
|
||||
builder.createWithAgentIds<mlir::arith::OrIOp>(loc, cond0, cond1);
|
||||
|
||||
// 2. Append phase and pipelineIdx to yield operands
|
||||
yieldOp->insertOperands(yieldOp.getNumOperands(),
|
||||
{nextPhase, nextPipelineIdx});
|
||||
|
||||
// 3. create newLoopArgs
|
||||
SmallVector<Value> newLoopArgs;
|
||||
for (auto operand : forOp.getInitArgs())
|
||||
newLoopArgs.push_back(operand);
|
||||
|
||||
builder.setInsertionPoint(forOp);
|
||||
Value initPipelineIdx, initEmptyIdx, initPhase;
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
if (parentForOp) {
|
||||
// Make sure the prior pipelineIdx is inserted at the end of parentForOp
|
||||
initPipelineIdx = parentForOp.getBody()->getArguments().back();
|
||||
Value numSteps = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, forOp.getUpperBound(), forOp.getLowerBound());
|
||||
numSteps = builder.createWithAgentIds<arith::AddIOp>(loc, numSteps,
|
||||
forOp.getStep());
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value two = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 2, 32);
|
||||
numSteps = builder.createWithAgentIds<arith::SubIOp>(loc, numSteps, one);
|
||||
numSteps = builder.createWithAgentIds<arith::DivUIOp>(loc, numSteps,
|
||||
forOp.getStep());
|
||||
// initPipelineIdx = (parentForOp.pipelineIdx * numSteps) % numStages
|
||||
// initPhase = ((parentForOp.pipelineIdx * numSteps) / numStages) & 1
|
||||
initPipelineIdx = builder.createWithAgentIds<arith::MulIOp>(
|
||||
loc, initPipelineIdx, numSteps);
|
||||
Value pipelineIdx = builder.createWithAgentIds<arith::DivUIOp>(
|
||||
loc, initPipelineIdx, numStagesVal);
|
||||
initPipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, initPipelineIdx,
|
||||
builder.createWithAgentIds<arith::MulIOp>(loc, pipelineIdx,
|
||||
numStagesVal));
|
||||
pipelineIdx =
|
||||
builder.createWithAgentIds<arith::AndIOp>(loc, pipelineIdx, one);
|
||||
initPhase = builder.createWithAgentIds<arith::TruncIOp>(
|
||||
loc, builder.getI1Type(), pipelineIdx);
|
||||
} else {
|
||||
// phase init to false and pipelineIdx init to 0
|
||||
initPipelineIdx = zero;
|
||||
initPhase = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 1);
|
||||
}
|
||||
newLoopArgs.append({initPhase, initPipelineIdx});
|
||||
|
||||
// 4. Create newForOp and take the region of forOp
|
||||
auto newForOp = builder.createWithAgentIds<scf::ForOp>(
|
||||
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
|
||||
newLoopArgs);
|
||||
newForOp.getRegion().takeBody(forOp.getRegion());
|
||||
|
||||
// 5. Replace forOp with newForOp
|
||||
for (unsigned i = 0; i < forOp.getNumResults(); ++i)
|
||||
forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));
|
||||
forOp.erase();
|
||||
|
||||
return newForOp;
|
||||
}
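The index/phase bookkeeping emitted by createNewMathLoop can be summarized with a small Python model (illustrative only; the pass generates the equivalent arith ops):

def next_state(idx, phase, num_stages):
    # nextIdx = idx + 1, wrapped to 0 when it reaches num_stages; the phase
    # bit flips exactly when the index wraps.
    nxt = idx + 1
    if nxt >= num_stages:
        return nxt - num_stages, phase ^ 1
    return nxt, phase

def init_state(parent_pipeline_idx, num_steps, num_stages):
    # Initial (pipelineIdx, phase) when the loop is nested inside parentForOp:
    # the parent has already advanced parent_pipeline_idx * num_steps steps.
    total = parent_pipeline_idx * num_steps
    return total % num_stages, (total // num_stages) & 1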
//===----------------------------------------------------------------------===//
|
||||
// appendPipelineIdxArgs
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -217,7 +329,22 @@ void appendPipelineIdxArgs(SmallVector<Operation *> &backbone, int numStages) {
|
||||
|
||||
for (auto &op : orderedForOps) {
|
||||
scf::ForOp parentForOp = op->getParentOfType<scf::ForOp>();
|
||||
auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp);
|
||||
scf::ForOp newForOp;
|
||||
bool hasDotOp = false;
|
||||
for (Operation &subOp : *op.getBody()) {
|
||||
if (isa<triton::DotOp>(&subOp)) {
|
||||
hasDotOp = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasDotOp) {
|
||||
// for(...) -> for(..., phase, pipelineIdx)
|
||||
newForOp = createNewMathLoop(op, numStages, parentForOp);
|
||||
} else {
|
||||
// for(...) -> for(..., pipelineIdx)
|
||||
newForOp = createNewPersistentLoop(op, numStages, parentForOp);
|
||||
}
|
||||
auto backboneForItr =
|
||||
std::find(backbone.begin(), backbone.end(), op.getOperation());
|
||||
if (backboneForItr != backbone.end()) {
|
||||
@@ -688,8 +815,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
}
|
||||
builder.setAgentIdsFromArray(agentsPC);
|
||||
Value pipelineIdx;
|
||||
Value numStagesVal = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headProducer->getLoc(), numStages, 32);
|
||||
if (auto forOp = headProducer->getParentOfType<scf::ForOp>()) {
|
||||
pipelineIdx = forOp.getBody()->getArguments().back();
|
||||
} else {
|
||||
@@ -700,10 +825,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
|
||||
// insert ProducerAcquireOp
|
||||
builder.setInsertionPoint(headProducer);
|
||||
if (headProducer->getParentOfType<scf::ForOp>()) {
|
||||
pipelineIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
headProducer->getLoc(), pipelineIdx, numStagesVal);
|
||||
}
|
||||
builder.setAgentIdsFromArray(agentP);
|
||||
builder.createWithAgentIds<ttng::ProducerAcquireOp>(headProducer->getLoc(),
|
||||
token, pipelineIdx);
|
||||
@@ -738,7 +859,8 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dot.replaceAllUsesWith(dotAsync.getResult());
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
loc, dotAsync.getResult(), 1);
|
||||
|
||||
// 1. insert ConsumerReleaseOp for DotAsyncOps
|
||||
Value cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
@@ -747,31 +869,43 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
auto ifOp =
|
||||
builder.createWithAgentIds<scf::IfOp>(loc, ArrayRef<Type>{}, cond,
|
||||
/*hasElse*/ false);
|
||||
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headConsumer->getLoc(), 1, 32);
|
||||
auto oriIdx = forOp.getBody()->getArguments().back();
|
||||
Value consumerReleaseIdx =
|
||||
builder.createWithAgentIds<arith::SubIOp>(loc, oriIdx, one);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
loc, consumerReleaseIdx, numStagesVal);
|
||||
Value consumerReleaseIdx = forOp.getBody()->getArguments().back();
|
||||
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
loc, numStages - 1, 32);
|
||||
Value consumerReleaseIdxMinusOne =
|
||||
builder.createWithAgentIds<arith::SubIOp>(loc, consumerReleaseIdx,
|
||||
one);
|
||||
cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
|
||||
loc, cond, lastStage, consumerReleaseIdxMinusOne);
|
||||
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
|
||||
consumerReleaseIdx);
|
||||
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
|
||||
0);
|
||||
unsigned resultIndex = dotAsync->getUses().begin()->getOperandNumber();
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
|
||||
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
|
||||
Value one_ = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headConsumer->getLoc(), 1, 32);
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
loc, numStages - 1, 32);
|
||||
consumerReleaseIdx = forOp.getResults().back();
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, consumerReleaseIdx, one_);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
loc, consumerReleaseIdx, numStagesVal);
|
||||
consumerReleaseIdxMinusOne = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, consumerReleaseIdx, one);
|
||||
cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
|
||||
loc, cond, lastStage, consumerReleaseIdxMinusOne);
|
||||
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
|
||||
consumerReleaseIdx);
|
||||
dotOp->erase();
|
|
||||
PUBLIC
|
||||
MLIRArithToLLVM
|
||||
MLIRBuiltinToLLVMIRTranslation
|
||||
MLIRExecutionEngineUtils
|
||||
MLIRIndexToLLVM
|
||||
MLIRIR
|
||||
MLIRLLVMDialect
|
||||
|
||||
@@ -44,7 +44,8 @@ static bool findAndReplace(std::string &str, const std::string &begin,
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
|
||||
bool enable_fp_fusion) {
|
||||
// LLVM version in use may not officially support target hardware.
|
||||
// Supported versions for LLVM 14 are here:
|
||||
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
|
||||
@@ -84,13 +85,15 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||
auto target =
|
||||
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
if (enable_fp_fusion)
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
opt.TrapUnreachable = true;
|
||||
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
|
||||
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
std::nullopt, llvm::CodeGenOpt::Aggressive)};
|
||||
std::nullopt, llvm::CodeGenOptLevel::Aggressive)};
|
||||
// set data layout
|
||||
if (layout.empty())
|
||||
module.setDataLayout(machine->createDataLayout());
|
||||
@@ -106,7 +109,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||
llvm::legacy::PassManager pass;
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, pstream, nullptr,
|
||||
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
llvm::CodeGenFileType::AssemblyFile);
|
||||
pass.run(module);
|
||||
}
|
||||
// post-process
|
||||
|
||||
@@ -73,23 +73,20 @@ def get_llvm_package_info():
|
||||
if arch == 'aarch64':
|
||||
arch = 'arm64'
|
||||
if system == "Darwin":
|
||||
system_suffix = "apple-darwin"
|
||||
arch = platform.machine()
|
||||
system_suffix = f"macos-{arch}"
|
||||
elif system == "Linux":
|
||||
# TODO: arm64
|
||||
vglibc = tuple(map(int, platform.libc_ver()[1].split('.')))
|
||||
vglibc = vglibc[0] * 100 + vglibc[1]
|
||||
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
|
||||
system_suffix = f"linux-gnu-{linux_suffix}"
|
||||
system_suffix = 'ubuntu-x64' if vglibc > 217 else 'centos-x64'
|
||||
else:
|
||||
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||
release_suffix = "assert" if use_assert_enabled_llvm else "release"
|
||||
name = f'llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}'
|
||||
version = "llvm-17.0.0-c5dede880d17"
|
||||
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
|
||||
# FIXME: remove the following once github.com/ptillet/triton-llvm-releases has arm64 llvm releases
|
||||
if arch == 'arm64' and 'linux' in system_suffix:
|
||||
url = f"https://github.com/acollins3/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
|
||||
# use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||
# release_suffix = "assert" if use_assert_enabled_llvm else "release"
|
||||
rev = "b1115f8c"
|
||||
name = f"llvm-{rev}-{system_suffix}"
|
||||
url = f"https://tritonlang.blob.core.windows.net/llvm-builds/{name}.tar.gz"
|
||||
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
|
||||
|
||||
|
||||
@@ -1236,7 +1236,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_minf",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MinFOp>(lhs, rhs));
|
||||
return mlir::Value(self.create<mlir::arith::MinimumFOp>(lhs, rhs));
|
||||
})
|
||||
.def("create_maxsi",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
@@ -1251,7 +1251,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_maxf",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MaxFOp>(lhs, rhs));
|
||||
return mlir::Value(self.create<mlir::arith::MaximumFOp>(lhs, rhs));
|
||||
})
|
||||
// AddPtr (similar to GEP)
|
||||
.def("create_addptr",
|
||||
@@ -2006,7 +2006,8 @@ void init_triton_translation(py::module &m) {
|
||||
|
||||
m.def(
|
||||
"translate_llvmir_to_ptx",
|
||||
[](const std::string llvmIR, int capability, int version) -> std::string {
|
||||
[](const std::string llvmIR, int capability, int version,
|
||||
bool enable_fp_fusion) -> std::string {
|
||||
py::gil_scoped_release allow_threads;
|
||||
// create LLVM module from C++
|
||||
llvm::LLVMContext context;
|
||||
@@ -2021,75 +2022,77 @@ void init_triton_translation(py::module &m) {
|
||||
"lineno: " + std::to_string(error.getLineNo()));
|
||||
}
|
||||
// translate module to PTX
|
||||
auto ptxCode =
|
||||
triton::translateLLVMIRToPTX(*module, capability, version);
|
||||
auto ptxCode = triton::translateLLVMIRToPTX(*module, capability,
|
||||
version, enable_fp_fusion);
|
||||
return ptxCode;
|
||||
},
|
||||
ret::take_ownership);
|
||||
|
||||
m.def(
|
||||
"compile_ptx_to_cubin",
|
||||
[](const std::string &ptxCode, const std::string &ptxasPath,
|
||||
int capability) -> py::object {
|
||||
std::string cubin;
|
||||
{
|
||||
py::gil_scoped_release allow_threads;
|
||||
m.def("compile_ptx_to_cubin",
|
||||
[](const std::string &ptxCode, const std::string &ptxasPath,
|
||||
int capability, bool enable_fp_fusion) -> py::object {
|
||||
std::string cubin;
|
||||
{
|
||||
py::gil_scoped_release allow_threads;
|
||||
|
||||
// compile ptx with ptxas
|
||||
llvm::SmallString<64> fsrc;
|
||||
llvm::SmallString<64> flog;
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
|
||||
std::string fbin = std::string(fsrc) + ".o";
|
||||
llvm::FileRemover logRemover(flog);
|
||||
llvm::FileRemover binRemover(fbin);
|
||||
const char *_fsrc = fsrc.c_str();
|
||||
const char *_flog = flog.c_str();
|
||||
const char *_fbin = fbin.c_str();
|
||||
std::ofstream ofs(_fsrc);
|
||||
ofs << ptxCode << std::endl;
|
||||
ofs.close();
|
||||
// compile ptx with ptxas
|
||||
llvm::SmallString<64> fsrc;
|
||||
llvm::SmallString<64> flog;
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
|
||||
std::string fbin = std::string(fsrc) + ".o";
|
||||
llvm::FileRemover logRemover(flog);
|
||||
llvm::FileRemover binRemover(fbin);
|
||||
const char *_fsrc = fsrc.c_str();
|
||||
const char *_flog = flog.c_str();
|
||||
const char *_fbin = fbin.c_str();
|
||||
std::ofstream ofs(_fsrc);
|
||||
ofs << ptxCode << std::endl;
|
||||
ofs.close();
|
||||
|
||||
auto lineInfoOption =
|
||||
triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
|
||||
? ""
|
||||
: " -lineinfo";
|
||||
auto capabilitySuffix = (capability == 90) ? "a " : " ";
|
||||
auto outputFileName = std::string(_fsrc) + ".o";
|
||||
auto logRedirect = " 2> " + std::string(_flog);
|
||||
std::string cmd = ptxasPath + lineInfoOption + " -v --gpu-name=sm_" +
|
||||
std::to_string(capability) + capabilitySuffix +
|
||||
_fsrc + " -o " + outputFileName + logRedirect;
|
||||
auto lineInfoOption =
|
||||
triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
|
||||
? ""
|
||||
: " -lineinfo";
|
||||
auto fmadOption = enable_fp_fusion ? "" : " --fmad=false";
|
||||
auto capabilitySuffix = (capability == 90) ? "a " : " ";
|
||||
auto outputFileName = std::string(_fsrc) + ".o";
|
||||
auto logRedirect = " 2> " + std::string(_flog);
|
||||
std::string cmd = ptxasPath + lineInfoOption + fmadOption +
|
||||
" -v --gpu-name=sm_" +
|
||||
std::to_string(capability) + capabilitySuffix +
|
||||
_fsrc + " -o " + outputFileName + logRedirect;
|
||||
|
||||
int err = system(cmd.c_str());
|
||||
if (err != 0) {
|
||||
err >>= 8;
|
||||
std::ifstream _log(_flog);
|
||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||
if (err == 255) {
|
||||
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
|
||||
log);
|
||||
} else if (err == 128 + SIGSEGV) {
|
||||
throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
|
||||
"` to confirm that this is a "
|
||||
"bug in `ptxas`\n" +
|
||||
log);
|
||||
int err = system(cmd.c_str());
|
||||
if (err != 0) {
|
||||
err >>= 8;
|
||||
std::ifstream _log(_flog);
|
||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||
if (err == 255) {
|
||||
throw std::runtime_error(
|
||||
"Internal Triton PTX codegen error: \n" + log);
|
||||
} else if (err == 128 + SIGSEGV) {
|
||||
throw std::runtime_error("Please run `ptxas " +
|
||||
fsrc.str().str() +
|
||||
"` to confirm that this is a "
|
||||
"bug in `ptxas`\n" +
|
||||
log);
|
||||
} else {
|
||||
throw std::runtime_error("`ptxas` failed with error code " +
|
||||
std::to_string(err) + ": \n" + log);
|
||||
}
|
||||
return {};
|
||||
} else {
|
||||
throw std::runtime_error("`ptxas` failed with error code " +
|
||||
std::to_string(err) + ": \n" + log);
|
||||
llvm::FileRemover srcRemover(fsrc);
|
||||
std::ifstream _cubin(_fbin, std::ios::binary);
|
||||
cubin = std::string(std::istreambuf_iterator<char>(_cubin), {});
|
||||
_cubin.close();
|
||||
// Do not return here, exit the gil scope and return below
|
||||
}
|
||||
return {};
|
||||
} else {
|
||||
llvm::FileRemover srcRemover(fsrc);
|
||||
std::ifstream _cubin(_fbin, std::ios::binary);
|
||||
cubin = std::string(std::istreambuf_iterator<char>(_cubin), {});
|
||||
_cubin.close();
|
||||
// Do not return here, exit the gil scope and return below
|
||||
}
|
||||
}
|
||||
py::bytes bytes(cubin);
|
||||
return std::move(bytes);
|
||||
});
|
||||
py::bytes bytes(cubin);
|
||||
return std::move(bytes);
|
||||
});
|
||||
|
||||
m.def("add_external_libs",
|
||||
[](mlir::ModuleOp &op, const std::vector<std::string> &names,
|
||||
|
||||
@@ -860,9 +860,9 @@ def test_unary_op(dtype_x, expr, num_ctas, device):
|
||||
# ----------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin']])
|
||||
def test_math_op(dtype_x, expr, device):
|
||||
_test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
|
||||
@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin'] for x in ['x', '3.0']])
|
||||
def test_math_op(dtype_x, expr, device, x):
|
||||
_test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device)
|
||||
|
||||
# ----------------
|
||||
# test abs
|
||||
@@ -1662,10 +1662,18 @@ reduce_configs2 = [
|
||||
for op in ['min', 'max', 'sum']
|
||||
]
|
||||
|
||||
reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)]
|
||||
reduce_configs3 = [
|
||||
(op, 'float32', shape, axis)
|
||||
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
|
||||
for shape in reduce3d_shapes
|
||||
for axis in [0, 1, 2]
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3)
|
||||
@pytest.mark.parametrize("num_ctas", num_ctas_list)
|
||||
def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
|
||||
def test_reduce(op, dtype_str, shape, axis, num_ctas, device):
|
||||
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
if is_hip():
|
||||
@@ -1673,17 +1681,31 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
|
||||
# triton kernel
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
|
||||
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr):
|
||||
range_m = tl.arange(0, BLOCK_M)
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 1:
|
||||
tl.store(Z + range_m, z)
|
||||
range_k = tl.arange(0, BLOCK_K)
|
||||
if IS_3D:
|
||||
x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :])
|
||||
else:
|
||||
tl.store(Z + range_n, z)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
if IS_3D:
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 0:
|
||||
tl.store(Z + range_n[:, None] * BLOCK_K + range_k[None, :], z)
|
||||
elif AXIS == 1:
|
||||
tl.store(Z + range_m[:, None] * BLOCK_K + range_k[None, :], z)
|
||||
else:
|
||||
tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
|
||||
else:
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 0:
|
||||
tl.store(Z + range_n, z)
|
||||
else:
|
||||
tl.store(Z + range_m, z)
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
|
||||
# input
|
||||
@@ -1706,10 +1728,13 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
ret_numel = 1 if axis is None else shape[1 - axis]
|
||||
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
|
||||
z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis)
|
||||
z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs),
|
||||
device=device, dst_type=z_tri_dtype_str)
|
||||
BLOCK_K = 1 if len(shape) == 2 else shape[2]
|
||||
IS_3D = bool(len(shape) == 3)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0],
|
||||
BLOCK_N=shape[1], AXIS=axis, num_ctas=num_ctas)
|
||||
BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas)
|
||||
z_tri = to_numpy(z_tri)
|
||||
# compare
|
||||
if op == 'sum':
|
||||
@@ -1890,7 +1915,7 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
|
||||
|
||||
ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
|
||||
arith_op = {
|
||||
"max": {"int32": "arith.maxsi", "float32": "arith.maxf", "float16": "arith.maxf"},
|
||||
"max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"},
|
||||
"sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
|
||||
}[reduce_op][dtype_str]
|
||||
numpy_op = {
|
||||
@@ -3049,6 +3074,20 @@ def test_constexpr_scalar_shape(device):
|
||||
kernel[(1,)](x_tri, 32)
|
||||
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_assert_func():
|
||||
tl.static_assert(tl.constexpr(False), "Assert is firing because the constexpr propagation did not work properly")
|
||||
|
||||
|
||||
def test_constexpr_propagation():
|
||||
@triton.jit
|
||||
def _kernel(COND: tl.constexpr):
|
||||
NEW_COND = COND
|
||||
if NEW_COND:
|
||||
static_assert_func()
|
||||
_kernel[(1,)](False)
|
||||
|
||||
# -------------
|
||||
# test call
|
||||
# -------------
|
||||
@@ -3893,3 +3932,22 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
|
||||
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
torch.testing.assert_close(ref_out, C)
|
||||
|
||||
# -----------------------
|
||||
# test enable_fp_fusion
|
||||
# -----------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_fp_fusion", [False, True])
|
||||
def test_enable_fp_fusion(enable_fp_fusion):
|
||||
# Sequential multiply add can be fused by backend
|
||||
@triton.jit
|
||||
def mul_add(data):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)
|
||||
|
||||
data = torch.randn((128,), device='cuda', dtype=torch.float32)
|
||||
h = mul_add[(1,)](data, enable_fp_fusion=enable_fp_fusion)
|
||||
|
||||
found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None
|
||||
assert found_fma == enable_fp_fusion
|
||||
|
||||
@@ -228,6 +228,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.local_defs: Dict[str, tensor] = {}
|
||||
self.global_uses: Dict[str, tensor] = {}
|
||||
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
|
||||
self.fn = None
|
||||
|
||||
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
|
||||
builtin_namespace.update((
|
||||
@@ -322,6 +323,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
if self.fn:
|
||||
raise UnsupportedLanguageConstruct(None, node, "nested function definition is not supported.")
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
arg_node = node.args.args[-i - 1]
|
||||
@@ -335,9 +338,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(self.fn)
|
||||
entry = self.fn.add_entry_block()
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
@@ -350,8 +353,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
if i in self.attributes:
|
||||
for name, value in self.attributes[i]:
|
||||
fn.set_arg_attr(idx, name, value)
|
||||
arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
self.fn.set_arg_attr(idx, name, value)
|
||||
arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
insert_pt = self.builder.get_insertion_block()
|
||||
@@ -367,14 +370,14 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# update return type
|
||||
if isinstance(self.last_ret_type, tuple):
|
||||
self.prototype.ret_types = list(self.last_ret_type)
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
self.fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
else:
|
||||
self.prototype.ret_types = [self.last_ret_type]
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
self.fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
if insert_pt:
|
||||
self.builder.set_insertion_point_to_end(insert_pt)
|
||||
# Remove dead code
|
||||
fn.finalize()
|
||||
self.fn.finalize()
|
||||
|
||||
def visit_arguments(self, node):
|
||||
arg_names = []
|
||||
@@ -412,6 +415,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
if not isinstance(node.value, ast.Constant) and _is_constexpr(values):
|
||||
self.set_value(names, values)
|
||||
return
|
||||
if not _is_list_like(names):
|
||||
names = [names]
|
||||
if not _is_list_like(values):
|
||||
@@ -686,9 +692,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for name in loop_defs:
|
||||
if name in liveins:
|
||||
# We should not def new constexpr
|
||||
assert _is_triton_tensor(loop_defs[name])
|
||||
assert _is_triton_tensor(liveins[name])
|
||||
assert loop_defs[name].type == liveins[name].type
|
||||
assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constexpr {name} in the loop'
assert _is_triton_tensor(liveins[name]), f'cannot reassign constexpr {name} in the loop'
|
||||
assert loop_defs[name].type == liveins[name].type, \
|
||||
f'Loop-carried variable {name} has initial type {liveins[name].type} '\
|
||||
f'but is re-assigned to {loop_defs[name].type} in loop! '\
|
||||
f'Please make sure that the type stays consistent.'
|
||||
|
||||
# these are loop-carried values
|
||||
names.append(name)
|
||||
ret_types.append(loop_defs[name].type)
|
||||
|
||||
@@ -37,6 +37,7 @@ CUDA_DEFAULT_WARP_SIZE = 32
|
||||
class CudaTargetDescriptor:
|
||||
capability: int
|
||||
num_warps: int
|
||||
enable_fp_fusion: bool
|
||||
|
||||
|
||||
def _is_cuda(target):
|
||||
@@ -147,6 +148,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability)
|
||||
pm.add_tritongpu_wsmutex_pass(capability)
|
||||
pm.add_tritongpu_wsmaterialization_pass(capability)
|
||||
pm.add_licm_pass()
|
||||
pm.add_cse_pass()
|
||||
else:
|
||||
if is_hip():
|
||||
@@ -220,7 +222,7 @@ def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None)
|
||||
if ptx_version is None:
|
||||
_, cuda_version = path_to_ptxas()
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
return translate_llvmir_to_ptx(mod, target.capability, ptx_version)
|
||||
return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion)
|
||||
|
||||
|
||||
def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
|
||||
@@ -231,7 +233,7 @@ def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
|
||||
:return: str
|
||||
'''
|
||||
ptxas, _ = path_to_ptxas()
|
||||
return compile_ptx_to_cubin(ptx, ptxas, target.capability)
|
||||
return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
@@ -407,8 +409,12 @@ def compile(fn, **kwargs):
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
|
||||
<<<<<<< HEAD
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
|
||||
=======
|
||||
enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
|
||||
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
|
||||
# TODO[shuhaoj]: persistent can be decoupled from warp specialization
|
||||
@@ -431,7 +437,7 @@ def compile(fn, **kwargs):
|
||||
# build architecture descriptor
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
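For context, a minimal sketch of how the `enable_fp_fusion` option threaded through `compile()` above could be exercised from Python; the kernel, tensors, and grid below are hypothetical, and only the reserved launch keywords mirror the merged signatures:

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x * 2.0 + 1.0, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# enable_fp_fusion joins num_warps/num_stages/num_ctas as a reserved launch
# keyword on the merged side; passing False should keep the fadd/fmul above
# from being contracted into an fma (assumption drawn from the signatures above).
_scale_kernel[grid](x, y, x.numel(), BLOCK=1024, num_warps=4, enable_fp_fusion=False)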
@@ -1057,14 +1057,19 @@ def load(pointer, mask=None, other=None, boundary_check=tuple(), padding_option=
|
||||
"""
|
||||
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
|
||||
(1) `pointer` could be a single element pointer, then a scalar will be loaded
|
||||
|
||||
- `mask` and `other` must be scalar too
|
||||
- `other` is implicitly typecast to `pointer.dtype.element_ty`
|
||||
- `boundary_check` and `padding_option` must be empty
|
||||
|
||||
(2) `pointer` could be element-wise tensor of pointers, in which case:
|
||||
|
||||
- `mask` and `other` are implicitly broadcast to `pointer.shape`
|
||||
- `other` is implicitly typecast to `pointer.dtype.element_ty`
|
||||
- `boundary_check` and `padding_option` must be empty
|
||||
|
||||
(3) `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
|
||||
|
||||
- `mask` and `other` must be None
|
||||
- `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access
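A small illustrative sketch of the three `pointer` forms described in the docstring above; the pointer argument, shapes, and strides are invented for the example:

import triton
import triton.language as tl

@triton.jit
def _load_forms(ptr, M, N, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # (1) single element pointer: a scalar is loaded
    first = tl.load(ptr)
    # (2) tensor of pointers: `mask` and `other` broadcast to the pointer shape
    offs_n = tl.arange(0, BLOCK_N)
    row = tl.load(ptr + offs_n, mask=offs_n < N, other=0.0)
    # (3) block pointer from make_block_ptr: no mask/other, use boundary_check
    #     and padding_option to handle out-of-bound accesses instead
    block_ptr = tl.make_block_ptr(base=ptr, shape=(M, N), strides=(stride_m, 1),
                                  offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N),
                                  order=(1, 0))
    tile = tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")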
@@ -1103,14 +1108,20 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
|
||||
"""
|
||||
Store a tensor of data into memory locations defined by `pointer`:
|
||||
(1) `pointer` could be a single element pointer, then a scalar will be stored
|
||||
|
||||
- `mask` must be scalar too
|
||||
- `boundary_check` and `padding_option` must be empty
|
||||
|
||||
(2) `pointer` could be element-wise tensor of pointers, in which case:
|
||||
|
||||
- `mask` is implicitly broadcast to `pointer.shape`
|
||||
- `boundary_check` must be empty
|
||||
|
||||
(3) or `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
|
||||
|
||||
- `mask` must be None
|
||||
- `boundary_check` can be specified to control the behavior of out-of-bound access
|
||||
|
||||
`value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
|
||||
|
||||
:param pointer: The memory location where the elements of `value` are stored
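A matching sketch for `store`; again the names, shapes, and strides are illustrative only:

import triton
import triton.language as tl

@triton.jit
def _store_forms(in_ptr, out_ptr, M, N, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)[:, None]
    offs_n = tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)
    tile = tl.load(in_ptr + offs_m * stride_m + offs_n, mask=mask, other=0.0)
    # (2) tensor of pointers: `mask` is broadcast, `boundary_check` must stay empty
    tl.store(out_ptr + offs_m * stride_m + offs_n, tile, mask=mask)
    # (3) block pointer: `mask` must be None; `boundary_check` guards the tail
    out_block = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(stride_m, 1),
                                  offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N),
                                  order=(1, 0))
    tl.store(out_block, tile, boundary_check=(0, 1))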
@@ -1165,29 +1176,35 @@ def advance(base: tensor, offsets, _builder=None):
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
|
||||
def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
docstr = f"""
|
||||
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
||||
|
||||
Return the data stored at :code:`pointer` before the atomic operation.
|
||||
|
||||
:param pointer: The memory locations to compare-and-swap.
|
||||
:type pointer: Block of dtype=triton.PointerDType
|
||||
:param pointer: The memory locations to operate on
|
||||
:type pointer: Block of dtype=triton.PointerDType"""
|
||||
if has_cmp:
|
||||
docstr += """
|
||||
:param cmp: The values expected to be found in the atomic object
|
||||
:type cmp: Block of dtype=`pointer.dtype.element_ty`
|
||||
:param val: The values to copy in case the expected value matches the contained value.
|
||||
:type val: Block of dtype=`pointer.dtype.element_ty`
|
||||
:type cmp: Block of dtype=pointer.dtype.element_ty"""
|
||||
docstr += """
|
||||
:param val: The values with which to perform the atomic operation
|
||||
:type val: Block of dtype=pointer.dtype.element_ty
|
||||
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
|
||||
"ACQUIRE", "RELEASE", or "RELAXED")
|
||||
:type sem: str
|
||||
"""
|
||||
func.__doc__ = docstr.format(name=name)
|
||||
func.__doc__ = docstr
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("compare-and-swap")
|
||||
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
|
||||
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
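A hedged sketch of the atomic builtins whose docstrings are reworked above, using the `sem` spellings listed in that docstring; the lock and counter pointers are hypothetical:

import triton
import triton.language as tl

@triton.jit
def _atomic_demo(lock_ptr, counter_ptr):
    # compare-and-swap: `cmp` is the expected value, `val` is written on a match,
    # and the value stored before the operation is returned
    old = tl.atomic_cas(lock_ptr, 0, 1, sem="ACQUIRE")
    # the other atomics take only `val`; `sem` overrides the default
    # "ACQUIRE_RELEASE" ordering
    tl.atomic_add(counter_ptr, 1, sem="RELAXED")
    tl.atomic_xchg(lock_ptr, 0, sem="RELEASE")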
@@ -1309,6 +1326,8 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
|
||||
:type ieee_rounding: bool
|
||||
"""
|
||||
ieee_rounding = _constexpr_to_value(ieee_rounding)
|
||||
x = _to_tensor(x, _builder)
|
||||
y = _to_tensor(y, _builder)
|
||||
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
||||
|
||||
|
||||
@@ -1330,36 +1349,42 @@ def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("exponential")
|
||||
def exp(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.exp(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("natural logarithm")
|
||||
def log(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.log(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("cosine")
|
||||
def cos(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.cos(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("sine")
|
||||
def sin(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.sin(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("square root")
|
||||
def sqrt(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.sqrt(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("absolute value")
|
||||
def abs(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.abs(x, _builder)
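A short illustrative kernel chaining several of the one-argument math builtins listed above; the buffer names and block size are made up:

import triton
import triton.language as tl

@triton.jit
def _math_demo(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # exp, log, sin, cos, sqrt, and abs all lower through the semantic.* calls above
    y = tl.sqrt(tl.abs(x)) + tl.exp(-x) + tl.sin(x) * tl.cos(x) + tl.log(1.0 + tl.abs(x))
    tl.store(out_ptr + offs, y, mask=mask)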
@@ -453,7 +453,7 @@ def _unwrap(tensor):
|
||||
|
||||
builder = Builder()
|
||||
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
@@ -281,13 +281,21 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
<<<<<<< HEAD
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
|
||||
=======
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
<<<<<<< HEAD
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
=======
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -297,7 +305,11 @@ class JITFunction(KernelInterface[T]):
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
<<<<<<< HEAD
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
=======
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
@@ -351,7 +363,11 @@ class JITFunction(KernelInterface[T]):
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
|
||||
<<<<<<< HEAD
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
|
||||
=======
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
from ..compiler import (CompiledKernel, compile,
|
||||
get_arch_default_num_stages,
|
||||
get_arch_default_num_warps)
|
||||
@@ -402,7 +418,11 @@ class JITFunction(KernelInterface[T]):
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
<<<<<<< HEAD
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug)
|
||||
=======
|
||||
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
@@ -430,8 +450,13 @@ class JITFunction(KernelInterface[T]):
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
<<<<<<< HEAD
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
=======
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
@@ -446,8 +471,13 @@ class JITFunction(KernelInterface[T]):
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
src = f"""
|
||||
import triton
|
||||
<<<<<<< HEAD
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
|
||||
=======
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
"""
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
|
||||
@@ -32,13 +32,17 @@ def _attn_fwd_inner(
|
||||
STAGE: tl.constexpr,
|
||||
offs_m: tl.constexpr,
|
||||
offs_n: tl.constexpr,
|
||||
N_CTX: tl.constexpr,
|
||||
):
|
||||
# range of values handled by this stage
|
||||
if STAGE == 1:
|
||||
lo, hi = 0, start_m * BLOCK_M
|
||||
else:
|
||||
elif STAGE == 2:
|
||||
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
||||
lo = tl.multiple_of(lo, BLOCK_M)
|
||||
# causal = False
|
||||
else:
|
||||
lo, hi = 0, N_CTX
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
|
||||
# loop over k, v and update accumulator
|
||||
@@ -72,6 +76,7 @@ def _attn_fwd_inner(
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
return acc, l_i, m_i
|
||||
|
||||
<<<<<<< HEAD
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc, l_i, m_i, q,
|
||||
@@ -138,6 +143,31 @@ def _attn_fwd_inner(
|
||||
],
|
||||
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
|
||||
)
|
||||
=======
|
||||
# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
|
||||
# the code below and commenting out the equivalent parameters is convenient for
|
||||
# re-tuning.
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=7, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=7, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=6, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=5, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=6, num_warps=4),
|
||||
# ],
|
||||
# key=['N_CTX'],
|
||||
# )
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -227,8 +257,12 @@ def _attn_fwd(
|
||||
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
||||
start_m,
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
|
||||
<<<<<<< HEAD
|
||||
4 - STAGE, offs_m, offs_n,
|
||||
N_CTX, pre_load_v,
|
||||
=======
|
||||
4 - STAGE, offs_m, offs_n, N_CTX,
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
)
|
||||
# stage 2: on-band
|
||||
if STAGE & 2:
|
||||
@@ -239,8 +273,12 @@ def _attn_fwd(
|
||||
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
||||
start_m,
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
|
||||
<<<<<<< HEAD
|
||||
2, offs_m, offs_n,
|
||||
N_CTX, pre_load_v,
|
||||
=======
|
||||
2, offs_m, offs_n, N_CTX,
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
)
|
||||
# epilogue
|
||||
# write back m
|
||||
@@ -701,6 +739,7 @@ class _attention(torch.autograd.Function):
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
<<<<<<< HEAD
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
if torch.version.hip is None:
|
||||
BLOCK_M = 128
|
||||
@@ -716,6 +755,20 @@ class _attention(torch.autograd.Function):
|
||||
)
|
||||
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
|
||||
=======
|
||||
o = torch.empty_like(q)
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
num_warps = 4
|
||||
stage = 3 if causal else 1
|
||||
# Tuning for H100
|
||||
if torch.cuda.get_device_capability()[0] == 9:
|
||||
num_warps = 8
|
||||
num_stages = 7 if Lk >= 64 else 3
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
_attn_fwd[grid](
|
||||
q, k, v, sm_scale, M, o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
@@ -726,6 +779,11 @@ class _attention(torch.autograd.Function):
|
||||
N_CTX=q.shape[2],
|
||||
BLOCK_DMODEL=Lk,
|
||||
STAGE=stage,
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
)
|
||||
|
||||
## restore the grid for bwd kernel
|
||||
@@ -905,6 +963,7 @@ try:
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
<<<<<<< HEAD
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = []
|
||||
for mode in ['fwd', 'bwd']:
|
||||
@@ -939,6 +998,36 @@ for mode in ['fwd', 'bwd']:
|
||||
'mode': mode,
|
||||
'causal': causal})
|
||||
)
|
||||
=======
|
||||
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = []
|
||||
for mode in ["fwd", "bwd"]:
|
||||
for causal in [True, False]:
|
||||
if mode == "bwd" and not causal:
|
||||
continue
|
||||
configs.append(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["N_CTX"],
|
||||
x_vals=[2**i for i in range(10, 15)],
|
||||
line_arg="provider",
|
||||
line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
|
||||
line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}",
|
||||
args={
|
||||
"H": N_HEADS,
|
||||
"BATCH": BATCH,
|
||||
"D_HEAD": D_HEAD,
|
||||
"dtype": torch.float16,
|
||||
"mode": mode,
|
||||
"causal": causal,
|
||||
},
|
||||
)
|
||||
)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -956,6 +1045,9 @@ def bench_flash_attention(
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
if mode == "fwd" and TORCH_HAS_FP8:
|
||||
q = q.to(torch.float8_e5m2)
|
||||
k = k.to(torch.float8_e5m2)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
if mode == "fwd":
|
||||
q = q.to(torch_dtype)
|
||||
|
||||
@@ -216,10 +216,10 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
|
||||
// CHECK-NEXT: %arg9 -> %cst_1
|
||||
// CHECK-NEXT: %0#0 -> %cst
|
||||
// CHECK-NEXT: %0#1 -> %cst_0
|
||||
// CHECK-NEXT: %0#2 -> %cst_2,%cst_2
|
||||
// CHECK-NEXT: %0#2 -> %cst_1,%cst_2,%cst_2
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: %arg11 -> %cst_1,%cst_2,%cst_2
|
||||
// CHECK-NEXT: %1 -> %cst_2,%cst_2
|
||||
// CHECK-NEXT: %1 -> %cst_1,%cst_2,%cst_2
|
||||
%c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: %2 -> %cst_2,%cst_2
|
||||
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
|
||||
@@ -257,9 +257,9 @@ tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %
|
||||
%5 = arith.cmpi slt, %1, %arg1 : index
|
||||
cf.cond_br %5, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
|
||||
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16>
|
||||
gpu.barrier
|
||||
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
|
||||
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16>
|
||||
%8 = arith.addi %1, %arg2 : index
|
||||
cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>)
|
||||
^bb3: // pred: ^bb1
|
||||
|
||||
@@ -864,6 +864,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
|
||||
// -----
|
||||
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#slice1d0 = #triton_gpu.slice<{dim = 0, parent = #blocked1}>
|
||||
#shared = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_1d
|
||||
tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}) {
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%cst_2 = arith.constant dense<64> : tensor<64xi32, #slice1d0>
|
||||
%58 = tt.splat %arg0 : (!tt.ptr<i64, 1>) -> tensor<64x!tt.ptr<i64, 1>, #slice1d0>
|
||||
%24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0>
|
||||
%59 = tt.addptr %58, %24 : tensor<64x!tt.ptr<i64, 1>, #slice1d0>, tensor<64xi32, #slice1d0>
|
||||
%66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr<i64, 1>, #slice1d0>, tensor<64xi32, #slice1d0>
|
||||
%71 = triton_gpu.alloc_tensor : tensor<2x64xi64, #shared>
|
||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
|
||||
// CHECK-NEXT: cp.async.commit_group
|
||||
%73 = triton_gpu.insert_slice_async %66, %71, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x!tt.ptr<i64, 1>, #slice1d0> -> tensor<2x64xi64, #shared>
|
||||
triton_gpu.async_commit_group
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
@@ -2012,6 +2043,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
|
||||
|
||||
// -----
|
||||
|
||||
<<<<<<< HEAD
|
||||
// CHECK-LABEL: copyitem
|
||||
// GCN: llvm.store
|
||||
// GCN: llvm.load
|
||||
@@ -2023,11 +2055,21 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @copyitem() attributes {noinline = false} {
|
||||
%cst = arith.constant dense<true> : tensor<4x1xi1, #blocked>
|
||||
=======
|
||||
// CHECK-LABEL: reduce_slice
|
||||
// CHECK-NOT: st.shared
|
||||
// CHECK-NOT: ld.shared
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}>
|
||||
#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @reduce_slice() attributes {noinline = false} {
|
||||
%cst = arith.constant dense<true> : tensor<4x1xi1, #sliced2>
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
%0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({
|
||||
^bb0(%arg0: i1, %arg1: i1):
|
||||
%1 = arith.ori %arg0, %arg1 : i1
|
||||
tt.reduce.return %1 : i1
|
||||
}) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
}) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #sliced2}>>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}>
|
||||
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
@@ -12,8 +12,8 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
|
||||
%dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared>
|
||||
%c0 = arith.constant 0 : i32
|
||||
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
|
||||
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 0>} : !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
|
||||
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
|
||||
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 0>} : !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
|
||||
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
|
||||
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
|
||||
// CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i16) : i16
|
||||
// CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C15]]
|
||||
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
|
||||
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -55,7 +55,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 :
|
||||
%src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>
|
||||
// CHECK: nvgpu.cluster_id
|
||||
// CHECK: nvgpu.tma_load_tiled
|
||||
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
|
||||
%res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0>} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr<i64, 3> -> tensor<1x64x64xf16, #shared>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -175,6 +175,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
// CHECK-LABEL: @dot_reg_operand_A
|
||||
// Generate a wgmma where the first operand is a struct.
|
||||
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: tensor<64x64xf16, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
|
||||
%opA = triton_gpu.convert_layout %a : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
|
||||
@@ -183,3 +184,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_fp8_to_f16_conversion
|
||||
tt.func @test_fp8_to_f16_conversion(
|
||||
%in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>,
|
||||
%in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) {
|
||||
// CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
|
||||
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
|
||||
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
|
||||
%out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
|
||||
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
|
||||
%out2 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
|
||||
%out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
|
||||
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
|
||||
%out4 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
|
||||
%out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 :
|
||||
nvgpu.cga_barrier_arrive
|
||||
nvgpu.cga_barrier_wait
|
||||
|
||||
%ptr = llvm.mlir.null : !llvm.ptr<i32, 3>
|
||||
%ptr = llvm.mlir.zero : !llvm.ptr<i32, 3>
|
||||
|
||||
// CHECK: llvm.inline_asm
|
||||
%v = nvgpu.cluster_id
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
|
||||
tt.func @test_mbarrier() {
|
||||
%mbarrier = llvm.mlir.null : !llvm.ptr<i64, 3>
|
||||
%mbarrier = llvm.mlir.zero : !llvm.ptr<i64, 3>
|
||||
%pred = arith.constant 1 : i1
|
||||
// CHECK: llvm.inline_asm
|
||||
nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr<i64, 3>
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
|
||||
tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) {
|
||||
%mbarrier = llvm.mlir.null : !llvm.ptr<i64, 3>
|
||||
%tmaDesc = llvm.mlir.null : !llvm.ptr<i8, 1>
|
||||
%dst = llvm.mlir.null : !llvm.ptr<i8, 3>
|
||||
%mbarrier = llvm.mlir.zero : !llvm.ptr<i64, 3>
|
||||
%tmaDesc = llvm.mlir.zero : !llvm.ptr<i8, 1>
|
||||
%dst = llvm.mlir.zero : !llvm.ptr<i8, 3>
|
||||
%l2desc = arith.constant 0 : i64
|
||||
%c0 = arith.constant 0 : i32
|
||||
%c1 = arith.constant 1 : i32
|
||||
@@ -16,13 +16,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
|
||||
|
||||
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
|
||||
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
|
||||
|
||||
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint
|
||||
// CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 2, 1>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i16
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 2, 1>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i16
|
||||
nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 4, 0>}: !llvm.ptr<i8, 3>, !llvm.ptr<i64, 3>, !llvm.ptr<i8, 1>, i64, i1, i32, i32, i32, i32
|
||||
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
|
||||
tt.func @test_tma(%opC : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) {
|
||||
%buffer = llvm.mlir.null : !llvm.ptr<i64, 3>
|
||||
%buffer = llvm.mlir.zero : !llvm.ptr<i64, 3>
|
||||
%height = arith.constant 16 : i32
|
||||
// CHECK: llvm.ptrtoint
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -30,3 +30,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
|
||||
tt.func @wgmma_wait(%in: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) {
|
||||
// CHECK: // wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63
|
||||
// CHECK: wgmma.wait_group.sync.aligned 0;
|
||||
%out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} :
|
||||
!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
test/Triton/print.mlir (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
// RUN: triton-translate %s --mlir-print-ir-after-all -o %t 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: IR Dump After SCFToControlFlow (convert-scf-to-cf)
|
||||
// CHECK: tt.func public @add_kernel_0d1d2d3de
|
||||
// CHECK: IR Dump After ConvertIndexToLLVMPass (convert-index-to-llvm)
|
||||
// CHECK: tt.func public @add_kernel_0d1d2d3de
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @add_kernel_0d1d2d3de(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
|
||||
%c1024_i32 = arith.constant 1024 : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c1024_i32 : i32
|
||||
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
|
||||
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
|
||||
%5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked>
|
||||
%6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked>
|
||||
%7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
|
||||
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
|
||||
%9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
|
||||
%10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
|
||||
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
|
||||
%12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
|
||||
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
|
||||
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
|
||||
tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -223,7 +223,7 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
|
||||
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
|
||||
// CHECK-LABEL: transpose
|
||||
// CHECK-LABEL: @transpose
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
@@ -920,7 +920,7 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
|
||||
%29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
|
||||
%30 = "tt.reduce" (%29) ({
|
||||
^bb0(%arg4: f32, %arg5: f32):
|
||||
%max = arith.maxf %arg4, %arg5 : f32
|
||||
%max = arith.maximumf %arg4, %arg5 : f32
|
||||
tt.reduce.return %max : f32
|
||||
}) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
|
||||
@@ -1697,11 +1697,11 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2>
|
||||
%123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({
|
||||
^bb0(%arg28: f32, %arg29: f32):
|
||||
%153 = arith.maxf %arg28, %arg29 : f32
|
||||
%153 = arith.maximumf %arg28, %arg29 : f32
|
||||
tt.reduce.return %153 : f32
|
||||
}) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%124 = triton_gpu.convert_layout %123 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1>
|
||||
%125 = arith.maxf %arg25, %124 : tensor<128xf32, #blocked1>
|
||||
%125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1>
|
||||
%126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1>
|
||||
%127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1>
|
||||
%128 = triton_gpu.convert_layout %125 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>
|
||||
|
||||
@@ -32,19 +32,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%14 = triton_nvidia_gpu.get_thread_id : i32
|
||||
%15 = arith.cmpi eq, %14, %c0_i32 : i32
|
||||
%16 = arith.andi %15, %10 : i1
|
||||
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1>
|
||||
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%20 = tt.advance %3, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%21 = tt.advance %6, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%22 = arith.cmpi sgt, %arg5, %c128_i32 : i32
|
||||
%23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1>
|
||||
%24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> <i64, 3>
|
||||
%25 = arith.andi %15, %22 : i1
|
||||
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1) : i32 {
|
||||
@@ -52,7 +52,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
triton_nvidia_gpu.mbarrier_wait %33, %arg23 : <i64, 3>
|
||||
// CHECK: triton_nvidia_gpu.fence_async_shared
|
||||
%34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma>
|
||||
triton_nvidia_gpu.dot_wait {pendings = 1 : i32}
|
||||
%35 = tt.advance %arg11, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%36 = tt.advance %arg12, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%37 = arith.addi %arg19, %c128_i32 : i32
|
||||
@@ -65,10 +64,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1>
|
||||
%45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> <i64, 3>
|
||||
%46 = arith.andi %15, %38 : i1
|
||||
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%b_48 = triton_gpu.convert_layout %48 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1>
|
||||
%s_48 = triton_gpu.convert_layout %b_48 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1>
|
||||
@@ -88,10 +87,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%64 = arith.ori %62, %63 : i1
|
||||
scf.yield %34, %35, %36, %47, %49, %s_48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1
|
||||
}
|
||||
scf.if %10 {
|
||||
triton_nvidia_gpu.dot_wait {pendings = 0 : i32}
|
||||
}
|
||||
%31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
|
||||
%w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>
|
||||
%31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
|
||||
%32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.store_async %8, %32 : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<128x128xf16, #shared1>
|
||||
triton_gpu.async_bulk_commit_group
|
||||
@@ -136,19 +133,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%14 = triton_nvidia_gpu.get_thread_id : i32
|
||||
%15 = arith.cmpi eq, %14, %c0_i32 : i32
|
||||
%16 = arith.andi %15, %10 : i1
|
||||
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1>
|
||||
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%20 = tt.advance %3, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%21 = tt.advance %6, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%22 = arith.cmpi sgt, %arg5, %c128_i32 : i32
|
||||
%23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1>
|
||||
%24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> <i64, 3>
|
||||
%25 = arith.andi %15, %22 : i1
|
||||
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%b_29 = triton_gpu.convert_layout %29 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1>
|
||||
@@ -158,7 +155,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
triton_nvidia_gpu.mbarrier_wait %33, %arg23 : <i64, 3>
|
||||
// CHECK: triton_nvidia_gpu.fence_async_shared
|
||||
%34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma>
|
||||
triton_nvidia_gpu.dot_wait {pendings = 1 : i32}
|
||||
%35 = tt.advance %arg11, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%36 = tt.advance %arg12, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
|
||||
%37 = arith.addi %arg19, %c128_i32 : i32
|
||||
@@ -171,10 +167,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1>
|
||||
%45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> <i64, 3>
|
||||
%46 = arith.andi %15, %38 : i1
|
||||
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr<i64, 3>, i1
|
||||
%47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 0>} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr<i64, 3>, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1>
|
||||
%50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1>
|
||||
%51 = arith.addi %arg20, %c1_i32 : i32
|
||||
%52 = arith.cmpi uge, %51, %c3_i32 : i32
|
||||
@@ -192,10 +188,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%64 = arith.ori %62, %63 : i1
|
||||
scf.yield %34, %35, %36, %47, %49, %48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1
|
||||
}
|
||||
scf.if %10 {
|
||||
triton_nvidia_gpu.dot_wait {pendings = 0 : i32}
|
||||
}
|
||||
%31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
|
||||
%w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>
|
||||
%31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
|
||||
%32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1>
|
||||
triton_nvidia_gpu.store_async %8, %32 : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<128x128xf16, #shared1>
|
||||
triton_gpu.async_bulk_commit_group
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
|
||||
// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
|
||||
// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
|
||||
// CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]]
|
||||
@@ -32,13 +31,13 @@
|
||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
|
||||
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
|
||||
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
|
||||
// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
|
||||
// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
|
||||
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32
|
||||
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]]
|
||||
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
|
||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
|
||||
@@ -46,7 +45,7 @@
|
||||
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||
// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]]
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]]
|
||||
tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
@@ -97,7 +96,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
|
||||
// CHECK: scf.for
|
||||
// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
|
||||
// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
|
||||
@@ -108,12 +106,12 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
|
||||
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
|
||||
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
|
||||
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
|
||||
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32
|
||||
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]]
|
||||
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
|
||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
|
||||
@@ -121,7 +119,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||
// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]]
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]]
|
||||
tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
@@ -174,23 +172,22 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
|
||||
// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
|
||||
// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
|
||||
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
||||
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
||||
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
|
||||
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
|
||||
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32
|
||||
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]]
|
||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
|
||||
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||
// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]]
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]]
|
||||
tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
|
||||
@@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
%a_tileptr_init = tt.make_tensor_ptr %A, [%c64, %c16], [%c16, %c1], [%c0, %c0] { order = array<i32: 1, 0> } : !tt.ptr<tensor<64x16xf16>, 1>
|
||||
// CHECK: %[[BUFFER:.*]] = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared>
|
||||
// CHECK: %[[MBAR:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : !tt.ptr<i64, 3>
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operand_segment_sizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr<i64, 3>, i1
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operandSegmentSizes = array<i32: 1, 1, 0>, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr<i64, 3>, i1
|
||||
// CHECK: %[[INSERT:.*]] = triton_nvidia_gpu.insert_slice_async_v2 %[[TENSOR_PTR]], %[[BUFFER]], %{{.*}}, %[[MBAR]]
|
||||
// CHECK: %[[EXT:.*]] = triton_gpu.extract_slice %[[INSERT]][0, 0, 0] [1, 64, 16] [1, 1, 1] : tensor<1x64x16xf16, #shared> to tensor<64x16xf16, #shared>
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_wait %[[MBAR]], %false : <i64, 3>
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: offset = 0, size = 49152
|
||||
// CHECK: offset = 49152, size = 49152
|
||||
// CHECK: size = 98304
|
||||
// CHECK: offset = 0, size = 32768
|
||||
// CHECK: offset = 32768, size = 32768
|
||||
// CHECK: size = 65536
|
||||
module {
|
||||
tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
|
||||
%cst = arith.constant dense<true> : tensor<64x64xi1>
|
||||
|
||||
@@ -1,234 +1,172 @@
|
||||
// RUN: triton-opt -split-input-file -triton-nvidia-gpu-ws-materialization='compute-capability=90' %s | FileCheck %s
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}>
|
||||
#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} {
|
||||
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
// CHECK-LABEL: @simple_gemm
|
||||
// CHECK: triton_nvidia_gpu.alloc_mbarrier
|
||||
// CHECK: scf.if
|
||||
// CHECK: scf.for
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_wait
|
||||
// CHECK: triton_gpu.insert_slice
|
||||
// CHECK: triton_gpu.insert_slice
|
||||
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
|
||||
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive
|
||||
// CHECK: scf.yield
|
||||
// CHECK: scf.if
|
||||
// CHECK: triton_nvidia_gpu.bar_wait
|
||||
// CHECK: scf.for
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_wait
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: tt.dot
|
||||
// CHECK: triton_nvidia_gpu.dot_async
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive
|
||||
// CHECK: scf.yield
|
||||
// CHECK: triton_nvidia_gpu.bar_arrive
|
||||
// CHECK: triton_nvidia_gpu.bar_wait
|
||||
// CHECK: triton_nvidia_gpu.dot_wait
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive
|
||||
// CHECK: tt.store
|
||||
// CHECK: triton_nvidia_gpu.bar_arrive
|
||||
tt.func public @simple_gemm(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
|
||||
%0 = triton_gpu.alloc_tensor : tensor<3x32x128xf16, #shared>
|
||||
%1 = triton_gpu.alloc_tensor : tensor<3x128x32xf16, #shared1>
|
||||
tt.func public @simple_gemm(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
|
||||
%0 = triton_gpu.alloc_tensor : tensor<3x128x64xf16, #shared>
|
||||
%1 = triton_gpu.alloc_tensor : tensor<3x64x128xf16, #shared1>
|
||||
%2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token>
|
||||
%3 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex
|
||||
%4 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex
|
||||
%5 = triton_nvidia_gpu.get_agent_id : i32
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%6 = arith.cmpi eq, %5, %c0_i32 : i32
|
||||
scf.if %6 {
|
||||
%cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<32x128xi32, #blocked>
|
||||
%cst_1 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<128x32xi32, #blocked1>
|
||||
%c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32
|
||||
%c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32
|
||||
%c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index
|
||||
%c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index
|
||||
%c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32
|
||||
%c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32
|
||||
%c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32
|
||||
%8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%9 = tt.get_program_id y {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%10 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%11 = arith.divsi %10, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%12 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%13 = arith.divsi %12, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%14 = arith.muli %13, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%15 = arith.divsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%16 = arith.muli %15, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%17 = arith.subi %11, %16 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%18 = arith.cmpi slt, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%19 = arith.select %18, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%20 = arith.remsi %8, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%21 = arith.addi %16, %20 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%22 = arith.remsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%23 = arith.divsi %22, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%24 = arith.muli %21, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%25 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%26 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%27 = tt.splat %24 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%28 = arith.addi %27, %25 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%29 = arith.muli %23, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%31 = arith.addi %30, %26 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%32 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%33 = arith.remsi %28, %32 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%34 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%35 = arith.remsi %31, %34 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%36 = arith.muli %9, %c32_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%37 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%38 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%39 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%40 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%41 = arith.addi %39, %37 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%42 = arith.addi %40, %38 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%43 = tt.expand_dims %33 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
|
||||
%44 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked1>
|
||||
%45 = arith.muli %43, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1>
|
||||
%46 = tt.expand_dims %41 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1>
|
||||
%47 = tt.broadcast %45 {async_agent = dense<0> : vector<1xi32>} : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
|
||||
%48 = tt.broadcast %46 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
|
||||
%49 = arith.addi %47, %48 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32xi32, #blocked1>
|
||||
%50 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<128x32x!tt.ptr<f16, 1>, #blocked1>
|
||||
%51 = tt.addptr %50, %49 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
|
||||
%52 = tt.expand_dims %42 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked>
|
||||
%53 = tt.expand_dims %35 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked>
|
||||
%54 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x128xi32, #blocked>
|
||||
%55 = arith.muli %53, %54 {async_agent = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked>
|
||||
%56 = tt.broadcast %52 {async_agent = dense<0> : vector<1xi32>} : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked>
|
||||
%57 = tt.broadcast %55 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked>
|
||||
%58 = arith.addi %56, %57 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128xi32, #blocked>
|
||||
%59 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<32x128x!tt.ptr<f16, 1>, #blocked>
|
||||
%60 = tt.addptr %59, %58 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
|
||||
%61 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%62 = arith.divsi %61, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%63 = arith.index_cast %62 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index
|
||||
%c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
%64:3 = scf.for %arg9 = %c0 to %63 step %c1 iter_args(%arg10 = %51, %arg11 = %60, %arg12 = %c0_i32_2) -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>, i32) {
|
||||
triton_nvidia_gpu.producer_acquire %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%65 = triton_gpu.insert_slice %arg10, %1, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1> -> tensor<3x128x32xf16, #shared1>
|
||||
%66 = triton_gpu.insert_slice %arg11, %0, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128x!tt.ptr<f16, 1>, #blocked> -> tensor<3x32x128xf16, #shared>
|
||||
%67 = tt.addptr %arg10, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
|
||||
%68 = tt.addptr %arg11, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
|
||||
%c1_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
|
||||
%c3_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
|
||||
%69 = arith.addi %arg12, %c1_i32_3 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%70 = arith.remsi %69, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.producer_commit %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
scf.yield %67, %68, %70 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>, i32
|
||||
} {async_agent = dense<0> : vector<1xi32>}
|
||||
}
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c1_i32_0 = arith.constant 1 : i32
|
||||
%7 = arith.cmpi sge, %5, %c1_i32_0 : i32
|
||||
scf.if %7 {
|
||||
%cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
|
||||
%c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32
|
||||
%c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32
|
||||
%c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index
|
||||
%c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index
|
||||
%c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32
|
||||
%c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32
|
||||
%c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32
|
||||
%8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%9 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%10 = arith.divsi %9, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%11 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%12 = arith.divsi %11, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%13 = arith.muli %12, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%14 = arith.divsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%15 = arith.muli %14, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%16 = arith.subi %10, %15 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%17 = arith.cmpi slt, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%18 = arith.select %17, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%19 = arith.remsi %8, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%20 = arith.addi %15, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%21 = arith.remsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%22 = arith.divsi %21, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%23 = arith.muli %20, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%24 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%25 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%26 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%27 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%28 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%29 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%30 = arith.addi %28, %24 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%31 = arith.addi %29, %26 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%32 = arith.muli %22, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%33 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%34 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%35 = arith.addi %33, %25 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%36 = arith.addi %34, %27 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%37 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%38 = tt.splat %arg4 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%39 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%40 = arith.divsi %39, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%41 = arith.index_cast %40 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index
|
||||
%c127_i32 = arith.constant 127 : i32
|
||||
%c1_i64 = arith.constant 1 : i64
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%3 = tt.get_program_id x : i32
|
||||
%4 = arith.addi %arg6, %c127_i32 : i32
|
||||
%5 = arith.divsi %4, %c128_i32 : i32
|
||||
%6 = arith.addi %arg5, %c127_i32 : i32
|
||||
%7 = arith.divsi %6, %c128_i32 : i32
|
||||
%8 = arith.muli %5, %c8_i32 : i32
|
||||
%9 = arith.divsi %3, %8 : i32
|
||||
%10 = arith.muli %9, %c8_i32 : i32
|
||||
%11 = arith.subi %7, %10 : i32
|
||||
%12 = arith.minsi %11, %c8_i32 : i32
|
||||
%13 = arith.remsi %3, %12 : i32
|
||||
%14 = arith.addi %10, %13 : i32
|
||||
%15 = arith.remsi %3, %8 : i32
|
||||
%16 = arith.divsi %15, %12 : i32
|
||||
%17 = arith.muli %14, %c128_i32 : i32
|
||||
%18 = arith.muli %16, %c128_i32 : i32
|
||||
%19 = arith.extsi %arg5 : i32 to i64
|
||||
%20 = arith.extsi %arg7 : i32 to i64
|
||||
%21 = arith.extsi %arg8 : i32 to i64
|
||||
%22 = tt.make_tensor_ptr %arg0, [%19, %20], [%21, %c1_i64], [%17, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked>, 1>
|
||||
%23 = arith.extsi %arg6 : i32 to i64
|
||||
%24 = arith.extsi %arg9 : i32 to i64
|
||||
%25 = tt.make_tensor_ptr %arg1, [%20, %23], [%c1_i64, %24], [%c0_i32, %18] {order = array<i32: 0, 1>} : <tensor<64x128xf16, #blocked1>, 1>
|
||||
%26 = arith.extsi %arg11 : i32 to i64
|
||||
%27 = tt.make_tensor_ptr %arg4, [%19, %23], [%26, %c1_i64], [%17, %18] {order = array<i32: 1, 0>} : <tensor<128x128xf32, #blocked>, 1>
|
||||
%28 = triton_nvidia_gpu.get_agent_id : i32
|
||||
%c0_i32_0 = arith.constant 0 : i32
|
||||
%29 = arith.cmpi eq, %28, %c0_i32_0 : i32
|
||||
scf.if %29 {
|
||||
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
|
||||
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
|
||||
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
triton_nvidia_gpu.lock %3 {mutex.barId = dense<1> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
%42:2 = scf.for %arg9 = %c0 to %41 step %c1 iter_args(%arg10 = %cst, %arg11 = %c0_i32_1) -> (tensor<128x128xf32, #mma>, i32) {
|
||||
triton_nvidia_gpu.consumer_wait %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%62 = triton_gpu.extract_slice %1[%arg11, 0, 0] [1, 128, 32] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x128x32xf16, #shared1> to tensor<128x32xf16, #shared1>
|
||||
%63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared>
|
||||
%64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1>
|
||||
%65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared>
|
||||
%66 = tt.dot %64, %65, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
|
||||
%c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32
|
||||
%67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%68 = arith.remsi %67, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.consumer_release %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
scf.yield %66, %68 : tensor<128x128xf32, #mma>, i32
|
||||
%false = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} false
|
||||
%31:4 = scf.for %arg12 = %c0_i32 to %arg7 step %c64_i32 iter_args(%arg13 = %22, %arg14 = %25, %arg15 = %false, %arg16 = %c0_i32_1) -> (!tt.ptr<tensor<128x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>, i1, i32) : i32 {
|
||||
triton_nvidia_gpu.producer_acquire %2, %arg16 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%32 = triton_gpu.insert_slice %arg13, %0, %arg16 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<128x64xf16, #blocked>, 1> -> tensor<3x128x64xf16, #shared>
|
||||
%33 = triton_gpu.insert_slice %arg14, %1, %arg16 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16, #blocked1>, 1> -> tensor<3x64x128xf16, #shared1>
|
||||
triton_nvidia_gpu.producer_commit %2, %arg16 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%34 = tt.advance %arg13, [%c0_i32, %c64_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<128x64xf16, #blocked>, 1>
|
||||
%35 = tt.advance %arg14, [%c64_i32, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x128xf16, #blocked1>, 1>
|
||||
%c1_i32_2 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
|
||||
%c0_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
|
||||
%true = arith.constant {async_agent = dense<0> : vector<1xi32>} true
|
||||
%36 = arith.addi %arg16, %c1_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%37 = arith.cmpi uge, %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%38 = arith.cmpi ult, %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%39 = arith.subi %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%40 = arith.select %37, %39, %36 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%41 = arith.xori %arg15, %true {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
%42 = arith.andi %37, %41 {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
%43 = arith.andi %38, %arg15 {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
%44 = arith.ori %42, %43 {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
scf.yield {async_agent = dense<0> : vector<1xi32>} %34, %35, %44, %40 : !tt.ptr<tensor<128x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>, i1, i32
|
||||
} {async_agent = dense<0> : vector<1xi32>}
|
||||
} {async_agent = dense<0> : vector<1xi32>}
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%30 = arith.cmpi eq, %28, %c1_i32 : i32
|
||||
scf.if %30 {
|
||||
%cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
|
||||
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
|
||||
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
|
||||
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
%false = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} false
|
||||
%31:3 = scf.for %arg12 = %c0_i32 to %arg7 step %c64_i32 iter_args(%arg13 = %cst, %arg14 = %false, %arg15 = %c0_i32_1) -> (tensor<128x128xf32, #mma>, i1, i32) : i32 {
|
||||
triton_nvidia_gpu.consumer_wait %2, %arg15 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%37 = triton_gpu.extract_slice %0[%arg15, 0, 0] [1, 128, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x128x64xf16, #shared> to tensor<128x64xf16, #shared>
|
||||
%38 = triton_gpu.convert_layout %37 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #shared>
|
||||
%39 = triton_gpu.extract_slice %1[%arg15, 0, 0] [1, 64, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x64x128xf16, #shared1> to tensor<64x128xf16, #shared1>
|
||||
%40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #shared1>
|
||||
%41 = triton_nvidia_gpu.dot_async %38, %40, %arg13 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<128x128xf32, #mma>
|
||||
%42 = arith.cmpi sgt, %arg12, %c0_i32 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
scf.if %42 {
|
||||
%c0_i32_6 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
|
||||
%c1_i32_7 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c2_i32_8 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
|
||||
%52 = arith.subi %arg15, %c1_i32_7 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%53 = arith.cmpi eq, %arg15, %c0_i32_6 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%54 = arith.select %53, %c2_i32_8, %52 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.consumer_release %2, %54 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
} {async_agent = dense<1> : vector<1xi32>}
|
||||
%c1_i32_4 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c0_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
|
||||
%true = arith.constant {async_agent = dense<1> : vector<1xi32>} true
|
||||
%43 = arith.addi %arg15, %c1_i32_4 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%44 = arith.cmpi uge, %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%45 = arith.cmpi ult, %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%46 = arith.subi %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%47 = arith.select %44, %46, %43 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%48 = arith.xori %arg14, %true {async_agent = dense<1> : vector<1xi32>} : i1
|
||||
%49 = arith.andi %44, %48 {async_agent = dense<1> : vector<1xi32>} : i1
|
||||
%50 = arith.andi %45, %arg14 {async_agent = dense<1> : vector<1xi32>} : i1
|
||||
%51 = arith.ori %49, %50 {async_agent = dense<1> : vector<1xi32>} : i1
|
||||
scf.yield {async_agent = dense<1> : vector<1xi32>} %41, %51, %47 : tensor<128x128xf32, #mma>, i1, i32
|
||||
} {async_agent = dense<1> : vector<1xi32>}
|
||||
triton_nvidia_gpu.unlock %3 {mutex.barId = dense<2> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
triton_nvidia_gpu.lock %4 {mutex.barId = dense<3> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
%43 = arith.truncf %42#0 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
|
||||
%44 = tt.expand_dims %30 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
|
||||
%45 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked2>
|
||||
%46 = arith.muli %44, %45 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked2>
|
||||
%47 = tt.expand_dims %35 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2>
|
||||
%48 = tt.broadcast %46 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
|
||||
%49 = tt.broadcast %47 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
|
||||
%50 = arith.addi %48, %49 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2>
|
||||
%51 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked2>
|
||||
%52 = tt.addptr %51, %50 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr<f16, 1>, #blocked2>, tensor<128x128xi32, #blocked2>
|
||||
%53 = "triton_gpu.cmpi"(%31, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%54 = tt.expand_dims %53 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2>
|
||||
%55 = "triton_gpu.cmpi"(%36, %38) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%56 = tt.expand_dims %55 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2>
|
||||
%57 = tt.broadcast %54 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2>
|
||||
%58 = tt.broadcast %56 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2>
|
||||
%59 = arith.andi %57, %58 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked2>
|
||||
%60 = triton_gpu.convert_layout %43 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2>
|
||||
tt.store %52, %60, %59 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2>
|
||||
triton_nvidia_gpu.unlock %4 {mutex.barId = dense<4> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
}
|
||||
%32 = triton_nvidia_gpu.dot_wait %31#0 {async_agent = dense<1> : vector<1xi32>, pendings = 0 : i32} : tensor<128x128xf32, #mma>
|
||||
%c0_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
|
||||
%c1_i32_3 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c2_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
|
||||
%33 = arith.subi %31#2, %c1_i32_3 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%34 = arith.cmpi eq, %31#2, %c0_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%35 = arith.select %34, %c2_i32, %33 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.consumer_release %2, %35 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%36 = triton_gpu.convert_layout %32 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf32, #mma>) -> tensor<128x128xf32, #blocked2>
|
||||
tt.store %27, %36 {async_agent = dense<1> : vector<1xi32>, boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<128x128xf32, #blocked>, 1>, tensor<128x128xf32, #blocked2>
|
||||
} {async_agent = dense<1> : vector<1xi32>}
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
// CHECK-LABEL: @matmal_from_wsmutex
|
||||
// CHECK: triton_nvidia_gpu.alloc_mbarrier
|
||||
// CHECK: scf.if
|
||||
// CHECK: scf.for
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_wait
|
||||
// CHECK: triton_gpu.insert_slice
|
||||
// CHECK: triton_gpu.insert_slice
|
||||
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
|
||||
// CHECK: triton_nvidia_gpu.insert_slice_async_v2
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive
|
||||
// CHECK: scf.yield
|
||||
@@ -239,174 +177,224 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_wait
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: tt.dot
|
||||
// CHECK: triton_nvidia_gpu.dot_async
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive
|
||||
// CHECK: scf.yield
|
||||
// CHECK: triton_nvidia_gpu.bar_arrive
|
||||
// CHECK: triton_nvidia_gpu.dot_wait
|
||||
// CHECK: triton_nvidia_gpu.extract_mbarrier
|
||||
// CHECK: triton_nvidia_gpu.mbarrier_arrive
|
||||
// CHECK: triton_nvidia_gpu.bar_wait
|
||||
// CHECK: tt.store
|
||||
// CHECK: triton_nvidia_gpu.bar_arrive
|
||||
tt.func public @matmal_from_wsmutex(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) {
|
||||
tt.func public @matmal_from_wsmutex(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
|
||||
%0 = triton_gpu.alloc_tensor : tensor<3x64x16xf16, #shared>
|
||||
%1 = triton_gpu.alloc_tensor : tensor<3x16x64xf16, #shared1>
|
||||
%2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token>
|
||||
%3 = triton_nvidia_gpu.get_agent_id : i32
|
||||
%c63_i32 = arith.constant 63 : i32
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%4 = arith.cmpi eq, %3, %c0_i32 : i32
|
||||
scf.if %4 {
|
||||
%cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<16x64xi32, #blocked>
|
||||
%cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<64x16xi32, #blocked1>
|
||||
%c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32
|
||||
%c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32
|
||||
%c1_i64 = arith.constant 1 : i64
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%3 = tt.get_program_id x : i32
|
||||
%4 = arith.addi %arg6, %c63_i32 : i32
|
||||
%5 = arith.divsi %4, %c64_i32 : i32
|
||||
%6 = arith.addi %arg5, %c63_i32 : i32
|
||||
%7 = arith.divsi %6, %c64_i32 : i32
|
||||
%8 = arith.muli %5, %c8_i32 : i32
|
||||
%9 = arith.divsi %3, %8 : i32
|
||||
%10 = arith.muli %9, %c8_i32 : i32
|
||||
%11 = arith.subi %7, %10 : i32
|
||||
%12 = arith.minsi %11, %c8_i32 : i32
|
||||
%13 = arith.remsi %3, %8 : i32
|
||||
%14 = arith.remsi %13, %12 : i32
|
||||
%15 = arith.addi %10, %14 : i32
|
||||
%16 = arith.divsi %13, %12 : i32
|
||||
%17 = arith.muli %15, %c64_i32 : i32
|
||||
%18 = arith.muli %16, %c64_i32 : i32
|
||||
%19 = arith.extsi %arg5 : i32 to i64
|
||||
%20 = arith.extsi %arg7 : i32 to i64
|
||||
%21 = arith.extsi %arg8 : i32 to i64
|
||||
%22 = tt.make_tensor_ptr %arg0, [%19, %20], [%21, %c1_i64], [%17, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x16xf16, #blocked>, 1>
|
||||
%23 = arith.extsi %arg6 : i32 to i64
|
||||
%24 = arith.extsi %arg9 : i32 to i64
|
||||
%25 = tt.make_tensor_ptr %arg1, [%20, %23], [%c1_i64, %24], [%c0_i32, %18] {order = array<i32: 0, 1>} : <tensor<16x64xf16, #blocked1>, 1>
|
||||
%26 = arith.extsi %arg10 : i32 to i64
|
||||
%27 = tt.make_tensor_ptr %arg4, [%19, %23], [%26, %c1_i64], [%17, %18] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #blocked>, 1>
|
||||
%28 = triton_nvidia_gpu.get_agent_id : i32
|
||||
%c0_i32_0 = arith.constant 0 : i32
|
||||
%29 = arith.cmpi eq, %28, %c0_i32_0 : i32
|
||||
scf.if %29 {
|
||||
%c132_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 132 : i32
|
||||
%c15_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 15 : i32
|
||||
%c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32
|
||||
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
|
||||
%6 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32
|
||||
%7 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%8 = arith.divsi %7, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%9 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%10 = arith.divsi %9, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%11 = arith.muli %8, %10 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%12 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%13 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%14 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%15 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%16 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%17 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%18 = tt.expand_dims %17 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x16xi32, #blocked1>
|
||||
%19 = tt.broadcast %18 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x16xi32, #blocked1>) -> tensor<64x16xi32, #blocked1>
|
||||
%20 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<64x16x!tt.ptr<f16, 1>, #blocked1>
|
||||
%21 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%22 = tt.expand_dims %21 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16x1xi32, #blocked>
|
||||
%23 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x64xi32, #blocked>
|
||||
%24 = tt.broadcast %22 {async_agent = dense<0> : vector<1xi32>} : (tensor<16x1xi32, #blocked>) -> tensor<16x64xi32, #blocked>
|
||||
%25 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr<f16, 1>) -> tensor<16x64x!tt.ptr<f16, 1>, #blocked>
|
||||
%31 = arith.muli %7, %5 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%32 = arith.addi %arg7, %c15_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%33 = arith.divsi %32, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%34 = arith.subi %c0_i32, %33 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%35 = arith.muli %34, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
|
||||
%c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
%26 = scf.for %arg9 = %6 to %11 step %c114_i32 iter_args(%arg10 = %c0_i32_2) -> (i32) : i32 {
|
||||
%27 = arith.divsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%28 = arith.remsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%29 = arith.muli %27, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%31 = arith.addi %30, %12 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%32 = arith.remsi %31, %14 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%33 = arith.muli %28, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%34 = tt.splat %33 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%35 = arith.addi %34, %13 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%36 = arith.remsi %35, %15 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%37 = tt.expand_dims %32 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
|
||||
%38 = arith.muli %37, %16 {async_agent = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1>
|
||||
%39 = tt.broadcast %38 {async_agent = dense<0> : vector<1xi32>} : (tensor<64x1xi32, #blocked1>) -> tensor<64x16xi32, #blocked1>
|
||||
%40 = arith.addi %39, %19 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16xi32, #blocked1>
|
||||
%41 = tt.addptr %20, %40 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<64x16xi32, #blocked1>
|
||||
%42 = tt.expand_dims %36 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi32, #blocked>
|
||||
%43 = arith.muli %42, %23 {async_agent = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked>
|
||||
%44 = tt.broadcast %43 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x64xi32, #blocked>) -> tensor<16x64xi32, #blocked>
|
||||
%45 = arith.addi %24, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64xi32, #blocked>
|
||||
%46 = tt.addptr %25, %45 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr<f16, 1>, #blocked>, tensor<16x64xi32, #blocked>
|
||||
%c3_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
|
||||
%47 = arith.subi %arg5, %c0_i32_1 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%48 = arith.divui %47, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%49 = arith.muli %arg10, %48 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%c3_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
|
||||
%50:3 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %41, %arg13 = %46, %arg14 = %49) -> (tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<16x64x!tt.ptr<f16, 1>, #blocked>, i32) : i32 {
|
||||
%52 = arith.remsi %arg14, %c3_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.producer_acquire %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%53 = triton_gpu.insert_slice %arg12, %0, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16x!tt.ptr<f16, 1>, #blocked1> -> tensor<3x64x16xf16, #shared>
|
||||
%54 = triton_gpu.insert_slice %arg13, %1, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16, 1>, #blocked> -> tensor<3x16x64xf16, #shared1>
|
||||
triton_nvidia_gpu.producer_commit %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%55 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<64x16xi32, #blocked1>
|
||||
%56 = tt.addptr %arg13, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr<f16, 1>, #blocked>, tensor<16x64xi32, #blocked>
|
||||
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
%36:5 = scf.for %arg11 = %3 to %31 step %c132_i32 iter_args(%arg12 = %22, %arg13 = %25, %arg14 = %15, %arg15 = %16, %arg16 = %c0_i32_1) -> (!tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i32, i32, i32) : i32 {
|
||||
%37 = arith.divsi %arg11, %8 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%38 = arith.muli %37, %c8_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%39 = arith.subi %7, %38 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%40 = arith.minsi %39, %c8_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%41 = arith.remsi %arg11, %8 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%42 = arith.remsi %41, %40 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%43 = arith.addi %38, %42 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%44 = arith.divsi %41, %40 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%45 = arith.subi %43, %arg14 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%46 = arith.muli %45, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%47 = tt.advance %arg12, [%46, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x16xf16, #blocked>, 1>
|
||||
%48 = arith.subi %44, %arg15 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%49 = arith.muli %48, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%50 = tt.advance %arg13, [%c0_i32, %49] {async_agent = dense<0> : vector<1xi32>} : <tensor<16x64xf16, #blocked1>, 1>
|
||||
%c3_i32_2 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
|
||||
%c0_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
|
||||
%51 = arith.subi %arg7, %c0_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%52 = arith.addi %51, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%c1_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
|
||||
%c2_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 2 : i32
|
||||
%53 = arith.subi %52, %c1_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%54 = arith.divui %53, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%55 = arith.muli %arg16, %54 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%56 = arith.divui %55, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%57 = arith.muli %56, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%58 = arith.subi %55, %57 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%59 = arith.andi %56, %c1_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%60 = arith.trunci %59 {async_agent = dense<0> : vector<1xi32>} : i32 to i1
|
||||
%61:4 = scf.for %arg17 = %c0_i32 to %arg7 step %c16_i32 iter_args(%arg18 = %47, %arg19 = %50, %arg20 = %60, %arg21 = %58) -> (!tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i1, i32) : i32 {
|
||||
triton_nvidia_gpu.producer_acquire %2, %arg21 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%65 = triton_gpu.insert_slice %arg18, %0, %arg21 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x16xf16, #blocked>, 1> -> tensor<3x64x16xf16, #shared>
|
||||
%66 = triton_gpu.insert_slice %arg19, %1, %arg21 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<3x16x64xf16, #shared1>
|
||||
triton_nvidia_gpu.producer_commit %2, %arg21 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%67 = tt.advance %arg18, [%c0_i32, %c16_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x16xf16, #blocked>, 1>
|
||||
%68 = tt.advance %arg19, [%c16_i32, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<16x64xf16, #blocked1>, 1>
|
||||
%c1_i32_6 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
|
||||
%57 = arith.addi %arg14, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
scf.yield {async_agent = dense<0> : vector<1xi32>} %55, %56, %57 : tensor<64x16x!tt.ptr<f16, 1>, #blocked1>, tensor<16x64x!tt.ptr<f16, 1>, #blocked>, i32
|
||||
%c0_i32_7 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
|
||||
%true = arith.constant {async_agent = dense<0> : vector<1xi32>} true
|
||||
%69 = arith.addi %arg21, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%70 = arith.cmpi uge, %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%71 = arith.cmpi ult, %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%72 = arith.subi %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%73 = arith.select %70, %72, %69 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
%74 = arith.xori %arg20, %true {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
%75 = arith.andi %70, %74 {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
%76 = arith.andi %71, %arg20 {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
%77 = arith.ori %75, %76 {async_agent = dense<0> : vector<1xi32>} : i1
|
||||
scf.yield {async_agent = dense<0> : vector<1xi32>} %67, %68, %77, %73 : !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i1, i32
|
||||
} {async_agent = dense<0> : vector<1xi32>}
|
||||
%62 = tt.advance %61#0, [%c0_i32, %35] {async_agent = dense<0> : vector<1xi32>} : <tensor<64x16xf16, #blocked>, 1>
|
||||
%63 = tt.advance %61#1, [%35, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : <tensor<16x64xf16, #blocked1>, 1>
|
||||
%c1_i32_5 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
|
||||
%51 = arith.addi %arg10, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
scf.yield {async_agent = dense<0> : vector<1xi32>} %51 : i32
|
||||
%64 = arith.addi %arg16, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32
|
||||
scf.yield {async_agent = dense<0> : vector<1xi32>} %62, %63, %43, %44, %64 : !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>, i32, i32, i32
|
||||
} {async_agent = dense<0> : vector<1xi32>}
|
||||
} {async_agent = dense<0> : vector<1xi32>}
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%5 = arith.cmpi eq, %3, %c1_i32 : i32
|
||||
scf.if %5 {
|
||||
%c0_i32_0 = arith.constant 0 : i32
|
||||
%6 = triton_nvidia_gpu.get_mutex_role_id {async_agent = dense<1> : vector<1xi32>, num = 2 : i32} : i32
|
||||
%7 = arith.cmpi ne, %6, %c0_i32_0 : i32
|
||||
%8 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
%9 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
%30 = arith.cmpi eq, %28, %c1_i32 : i32
|
||||
scf.if %30 {
|
||||
%c0_i32_1 = arith.constant 0 : i32
|
||||
%31 = triton_nvidia_gpu.get_mutex_role_id {async_agent = dense<1> : vector<1xi32>, num = 2 : i32} : i32
|
||||
%32 = arith.cmpi ne, %31, %c0_i32_1 : i32
|
||||
%33 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
%34 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
|
||||
%cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf32, #mma>
|
||||
%c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32
|
||||
%c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32
|
||||
%c132_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 132 : i32
|
||||
%c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32
|
||||
%c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
%c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
|
||||
%10 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32
|
||||
%11 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%12 = arith.divsi %11, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%13 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%14 = arith.divsi %13, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%15 = arith.muli %12, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%16 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%17 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%18 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked2>
|
||||
%19 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr<f32, 1>) -> tensor<64x1x!tt.ptr<f32, 1>, #blocked2>
|
||||
%35 = arith.muli %7, %5 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
|
||||
%c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
|
||||
%c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
|
||||
%20 = arith.muli %c114_i32, %6 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%21 = arith.addi %10, %20 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%36 = arith.muli %c132_i32, %31 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%37 = arith.addi %3, %36 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%c2_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
|
||||
%22 = arith.muli %c114_i32, %c2_i32 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%23 = arith.addi %c0_i32_2, %6 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%24 = scf.for %arg9 = %21 to %15 step %22 iter_args(%arg10 = %23) -> (i32) : i32 {
|
||||
%25 = arith.cmpi ne, %arg9, %10 : i32
|
||||
%26 = arith.ori %25, %7 {agent.mutex_role = 0 : i32} : i1
|
||||
scf.if %26 {
|
||||
triton_nvidia_gpu.lock %8 {agent.mutex_role = 0 : i32} : !triton_nvidia_gpu.mutex
|
||||
%38 = arith.muli %c132_i32, %c2_i32 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%39 = arith.addi %c0_i32_2, %31 {async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%40:4 = scf.for %arg11 = %37 to %35 step %38 iter_args(%arg12 = %27, %arg13 = %15, %arg14 = %16, %arg15 = %39) -> (!tt.ptr<tensor<64x64xf16, #blocked>, 1>, i32, i32, i32) : i32 {
|
||||
%41 = arith.cmpi ne, %arg11, %3 : i32
|
||||
%42 = arith.ori %41, %32 : i1
|
||||
scf.if %42 {
|
||||
triton_nvidia_gpu.lock %33 {agent.mutex_role = 0 : i32} : !triton_nvidia_gpu.mutex
|
||||
} {agent.mutex_role = 0 : i32}
|
||||
%27 = arith.divsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%28 = arith.remsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%29 = arith.muli %27, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%30 = tt.splat %29 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%31 = arith.addi %30, %17 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%32 = arith.muli %28, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%33 = tt.splat %32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%34 = arith.addi %33, %16 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%43 = arith.divsi %arg11, %8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%44 = arith.muli %43, %c8_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%45 = arith.subi %7, %44 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%46 = arith.minsi %45, %c8_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%47 = arith.remsi %arg11, %8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%48 = arith.remsi %47, %46 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%49 = arith.addi %44, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%50 = arith.divsi %47, %46 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%51 = arith.subi %49, %arg13 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%52 = arith.muli %51, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%53 = arith.subi %50, %arg14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%54 = arith.muli %53, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%c3_i32_3 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32
|
||||
%35 = arith.subi %arg5, %c0_i32_1 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%36 = arith.divui %35, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%37 = arith.muli %arg10, %36 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%c3_i32_4 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32
|
||||
%38:2 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %cst, %arg13 = %37) -> (tensor<64x64xf32, #mma>, i32) : i32 {
|
||||
%48 = arith.remsi %arg13, %c3_i32_4 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.consumer_wait %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%49 = triton_gpu.extract_slice %0[%48, 0, 0] [1, 64, 16] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared>
|
||||
%50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
|
||||
%51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
|
||||
%52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
|
||||
%53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
scf.yield {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} %53, %54 : tensor<64x64xf32, #mma>, i32
|
||||
} {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>}
|
||||
triton_nvidia_gpu.unlock %8 : !triton_nvidia_gpu.mutex
|
||||
scf.if %26 {
|
||||
triton_nvidia_gpu.lock %9 {agent.mutex_role = 1 : i32} : !triton_nvidia_gpu.mutex
|
||||
%c0_i32_4 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
|
||||
%55 = arith.subi %arg7, %c0_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%56 = arith.addi %55, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%c1_i32_5 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c2_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
|
||||
%57 = arith.subi %56, %c1_i32_5 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%58 = arith.divui %57, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%59 = arith.muli %arg15, %58 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%60 = arith.divui %59, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%61 = arith.muli %60, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%62 = arith.subi %59, %61 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%63 = arith.andi %60, %c1_i32_5 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%64 = arith.trunci %63 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 to i1
|
||||
%65:3 = scf.for %arg16 = %c0_i32 to %arg7 step %c16_i32 iter_args(%arg17 = %cst, %arg18 = %64, %arg19 = %62) -> (tensor<64x64xf32, #mma>, i1, i32) : i32 {
|
||||
triton_nvidia_gpu.consumer_wait %2, %arg19 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
%74 = triton_gpu.extract_slice %0[%arg19, 0, 0] [1, 64, 16] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared>
|
||||
%75 = triton_gpu.convert_layout %74 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
|
||||
%76 = triton_gpu.extract_slice %1[%arg19, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
|
||||
%77 = triton_gpu.convert_layout %76 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
|
||||
%78 = triton_nvidia_gpu.dot_async %75, %77, %arg17 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
%79 = arith.cmpi sgt, %arg16, %c0_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
scf.if %79 {
|
||||
%c0_i32_13 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
|
||||
%c1_i32_14 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c2_i32_15 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
|
||||
%89 = arith.subi %arg19, %c1_i32_14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%90 = arith.cmpi eq, %arg19, %c0_i32_13 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%91 = arith.select %90, %c2_i32_15, %89 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.consumer_release %2, %91 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
} {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>}
|
||||
%c1_i32_11 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c0_i32_12 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
|
||||
%true = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} true
|
||||
%80 = arith.addi %arg19, %c1_i32_11 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%81 = arith.cmpi uge, %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%82 = arith.cmpi ult, %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%83 = arith.subi %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%84 = arith.select %81, %83, %80 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%85 = arith.xori %arg18, %true {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
|
||||
%86 = arith.andi %81, %85 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
|
||||
%87 = arith.andi %82, %arg18 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
|
||||
%88 = arith.ori %86, %87 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
|
||||
scf.yield {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} %78, %88, %84 : tensor<64x64xf32, #mma>, i1, i32
|
||||
} {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>}
|
||||
triton_nvidia_gpu.unlock %33 : !triton_nvidia_gpu.mutex
|
||||
%66 = triton_nvidia_gpu.dot_wait %65#0 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>, pendings = 0 : i32} : tensor<64x64xf32, #mma>
|
||||
%c0_i32_7 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
|
||||
%c1_i32_8 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%c2_i32_9 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
|
||||
%67 = arith.subi %65#2, %c1_i32_8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%68 = arith.cmpi eq, %65#2, %c0_i32_7 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
%69 = arith.select %68, %c2_i32_9, %67 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
triton_nvidia_gpu.consumer_release %2, %69 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
|
||||
scf.if %42 {
|
||||
triton_nvidia_gpu.lock %34 {agent.mutex_role = 1 : i32} : !triton_nvidia_gpu.mutex
|
||||
} {agent.mutex_role = 1 : i32}
|
||||
%39 = tt.expand_dims %31 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2>
|
||||
%40 = arith.muli %39, %18 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked2>
|
||||
%41 = tt.addptr %19, %40 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr<f32, 1>, #blocked2>, tensor<64x1xi32, #blocked2>
|
||||
%42 = tt.expand_dims %34 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
|
||||
%43 = tt.broadcast %41 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x1x!tt.ptr<f32, 1>, #blocked2>) -> tensor<64x64x!tt.ptr<f32, 1>, #blocked2>
|
||||
%44 = tt.broadcast %42 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%45 = tt.addptr %43, %44 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x64x!tt.ptr<f32, 1>, #blocked2>, tensor<64x64xi32, #blocked2>
|
||||
%46 = triton_gpu.convert_layout %38#0 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked2>
|
||||
tt.store %45, %46 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked2>
|
||||
triton_nvidia_gpu.unlock %9 : !triton_nvidia_gpu.mutex
|
||||
%c1_i32_5 = arith.constant {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%47 = arith.addi %arg10, %c2_i32 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
scf.yield {async_agent = dense<1> : vector<1xi32>} %47 : i32
|
||||
%70 = arith.truncf %66 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma>
|
||||
%71 = tt.advance %arg12, [%52, %54] {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : <tensor<64x64xf16, #blocked>, 1>
|
||||
%72 = triton_gpu.convert_layout %70 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf16, #mma>) -> tensor<64x64xf16, #blocked2>
|
||||
tt.store %71, %72 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, tensor<64x64xf16, #blocked2>
|
||||
triton_nvidia_gpu.unlock %34 : !triton_nvidia_gpu.mutex
|
||||
%c1_i32_10 = arith.constant {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
|
||||
%73 = arith.addi %arg15, %c2_i32 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : i32
|
||||
scf.yield {async_agent = dense<1> : vector<1xi32>} %71, %49, %50, %73 : !tt.ptr<tensor<64x64xf16, #blocked>, 1>, i32, i32, i32
|
||||
} {async_agent = dense<1> : vector<1xi32>}
|
||||
} {"agent.num-roles" = 2 : i32, async_agent = dense<1> : vector<1xi32>}
|
||||
tt.return
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
|
||||
// CHECK: scf.if
|
||||
// CHECK: triton_nvidia_gpu.create_mutex
|
||||
// CHECK: triton_nvidia_gpu.create_mutex
|
||||
// CHECK: scf.for
|
||||
// CHECK: triton_nvidia_gpu.create_mutex
|
||||
// CHECK: triton_nvidia_gpu.create_mutex
|
||||
// CHECK: triton_nvidia_gpu.lock
|
||||
// CHECK: agent.mutex_role = 0 : i32
|
||||
// CHECK: triton_nvidia_gpu.unlock
|
||||
|
||||
@@ -22,9 +22,10 @@
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: triton_nvidia_gpu.dot_async
|
||||
// CHECK: triton_nvidia_gpu.dot_wait
|
||||
// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 1
|
||||
// CHECK: triton_nvidia_gpu.consumer_release
|
||||
// CHECK: scf.yield
|
||||
// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 0
|
||||
// CHECK: async_agent = dense<1> : vector<1xi32>
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
|
||||
@@ -20,7 +20,7 @@ struct TestAliasPass
|
||||
return opName;
|
||||
}
|
||||
|
||||
static void print(StringRef name, SmallVector<std::string, 4> &vals,
|
||||
static void print(StringRef name, SmallVector<std::string> &vals,
|
||||
raw_ostream &os) {
|
||||
if (vals.empty())
|
||||
return;
|
||||
@@ -57,7 +57,7 @@ struct TestAliasPass
|
||||
auto getAllocOpNames = [&](Value value) {
|
||||
dataflow::Lattice<AliasInfo> *latticeElement =
|
||||
analysis->getLatticeElement(value);
|
||||
SmallVector<std::string, 4> opNames;
|
||||
SmallVector<std::string> opNames;
|
||||
if (latticeElement) {
|
||||
auto &info = latticeElement->getValue();
|
||||
for (auto &alias : info.getAllocs()) {
|
||||
Submodule third_party/triton_shared updated: d0ac5898ff...07ea84207a