Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108

Conflicts:
	bin/triton-translate.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-08 18:51:23 +00:00
72 changed files with 1623 additions and 838 deletions

View File

@@ -885,7 +885,8 @@ public:
//===----------------------------------------------------------------------===//
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
: dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
solver) {
// UnrealizedConversionCast:
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast

View File

@@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
getParentOrder(getSrcLayout())[0];
}
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = triton::gpu::getOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
// insert axis at the beginning of order
order.insert(order.begin(), axis);
return order;
}
// Thread offset is the thread index offset of two adjacent threads on the
// reduction axis within the warp.
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
@@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
if (threadsPerWarp.size() == 1) {
threadOffset = 1;
} else {
assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
auto order = triton::gpu::getOrder(srcLayout);
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
threadOffset *= threadsPerWarp[order[i]];
}
}
return threadOffset;
@@ -150,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
}
bool ReduceOpHelper::isWarpSynchronous() {
auto argsLayout = getSrcLayout();
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
auto srcLayout = getSrcLayout();
auto srcShape = getSrcShape();
return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
1;
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
@@ -502,10 +515,10 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src && dst && src.getVersionMajor() == 3 &&
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
dst.getWarpsPerCTA()[1] == 1;
}
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}
@@ -521,7 +534,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getElementType().isF16();
}
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
@@ -713,7 +726,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = backwardFilter;
getBackwardSlice(currentOp, &backwardSlice, opt);
slice.insert(backwardSlice.begin(), backwardSlice.end());
// Compute and insert the forwardSlice starting from currentOp.

View File

@@ -14,7 +14,7 @@ add_mlir_conversion_library(NVGPUToLLVM
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms

View File

@@ -29,8 +29,6 @@ const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;";
const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;";
const std::string Cga_Barrier_Sync_op = "barrier.cluster.sync.aligned;";
const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;";
const std::string Wgmma_Wait_Group_Op =
"wgmma.wait_group.sync.aligned #pendings;";
const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;";
const std::string Fence_Mbarrier_Init_Op =
"fence.mbarrier_init.release.cluster;";
@@ -200,29 +198,6 @@ public:
return {};
}
Type getReturnType(std::vector<std::string> outputConstraints,
mlir::PatternRewriter &rewriter) const {
auto ctx = rewriter.getContext();
Type resTy;
if (outputConstraints.empty()) {
resTy = void_ty(ctx);
} else {
SmallVector<Type> retTys;
for (auto &outputConstraint : outputConstraints) {
assert(outputConstraint[0] == '=' &&
"Constraint must be for an output");
Type retTy = getTypeFromConstraint(outputConstraint[1], rewriter);
retTys.push_back(retTy);
}
if (retTys.size() == 1) {
resTy = retTys[0];
} else {
resTy = struct_ty(retTys);
}
}
return resTy;
}
std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const {
std::vector<std::pair<int, int>> patchLocations;
std::vector<std::string> patchValues;
@@ -285,7 +260,8 @@ public:
outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end());
auto &ptxInstr = *ptxBuilder.create<PTXInstr>(ptxAsmPatched);
ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true);
auto retTy = getReturnType(outputConstraints, rewriter);
auto retTy =
op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType();
auto res = ptxBuilder.launch(rewriter, loc, retTy,
/*hasSideEffects*/ hasSideEffects);
if (op->getNumResults() == 0) {
@@ -700,6 +676,45 @@ public:
}
};
class WGMMAWaitGroupOpPattern
: public NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp,
WGMMAWaitGroupOpPattern> {
public:
using Base =
NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp, WGMMAWaitGroupOpPattern>;
using Base::Base;
std::vector<std::string>
getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = op.getType().cast<LLVM::LLVMStructType>();
uint32_t numOutputRegs = outputStructType.getBody().size();
std::string output =
outputStructType.getBody().front().isF32() ? "=f" : "=r";
return std::vector<std::string>(numOutputRegs, output);
}
OperandsAndConstraints
getOperandsAndConstraints(ttn::WGMMAWaitGroupOp op) const {
OperandsAndConstraints operandsAndConstraints;
auto input = op.getInput();
operandsAndConstraints.push_back({input, "0"});
return operandsAndConstraints;
}
std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = op.getType().dyn_cast<LLVM::LLVMStructType>();
uint32_t numCRegs = outputStructType.getBody().size();
std::string args = "";
uint32_t asmOpIdx = 0;
for (uint32_t i = 0; i < numCRegs; ++i) {
args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
}
auto ptxAsm = "// wait for regs: " + args + "\n\t" +
"wgmma.wait_group.sync.aligned #pendings;";
return ptxAsm;
}
};
class WGMMAOpPattern : public NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern> {
public:
using Base = NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern>;
@@ -1072,7 +1087,6 @@ public:
POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op)
POPULATE_NVGPU_OP(ttn::CGABarrierSyncOp, Cga_Barrier_Sync_op)
POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op)
POPULATE_NVGPU_OP(ttn::WGMMAWaitGroupOp, Wgmma_Wait_Group_Op)
POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op)
POPULATE_NVGPU_OP(ttn::FenceMBarrierInitOp, Fence_Mbarrier_Init_Op)
POPULATE_NVGPU_OP(ttn::CGABarrierArriveOp, Cga_Barrier_Arrive_Op)
@@ -1100,7 +1114,8 @@ public:
OffsetOfStmatrixV4OpPattern, MBarrierArriveOpPattern,
ClusterArriveOpPattern, TMALoadTiledOpPattern,
TMAStoreTiledOpPattern, LoadDSmemOpPattern, WGMMAOpPattern,
StoreDSmemOpPattern, OffsetOfSts64OpPattern>(context);
WGMMAWaitGroupOpPattern, StoreDSmemOpPattern,
OffsetOfSts64OpPattern>(context);
if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed())
signalPassFailure();

View File

@@ -49,7 +49,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
ASMBuilder
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms

View File

@@ -146,10 +146,8 @@ struct DotWaitOpConversion
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto pendings = op.getPendings();
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(op.getLoc(), pendings);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
op, adaptor.getInput(), pendings);
return success();
}
};

View File

@@ -168,7 +168,9 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
int numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0] * wpt[0]);
int numRepK = ceil<unsigned>(shapePerCTA[1], instrShape[2]);
Value warp = udiv(thread, i32_val(32));
// The descriptor should be calculated based on the first warp of the
// warpgroup.
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
Value warpM = urem(warp, i32_val(wpt[0]));
Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
@@ -199,7 +201,7 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
int numRepK = ceil<unsigned>(shapePerCTA[0], instrShape[2]);
int numRepN = ceil<unsigned>(shapePerCTA[1], instrShape[1] * wpt[1]);
Value warp = udiv(thread, i32_val(32));
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
Value warpMN = udiv(warp, i32_val(wpt[0]));
Value warpN = urem(warpMN, i32_val(wpt[1]));
Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
@@ -293,6 +295,26 @@ static bool isZero(Value v) {
return false;
}
static SmallVector<Value> emitWait(ConversionPatternRewriter &rewriter,
Location loc, SmallVector<Value> acc,
int pendings) {
SmallVector<Type> types(acc.size(), acc[0].getType());
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
int i = 0;
for (Value v : acc) {
llvmStruct = insert_val(structTy, llvmStruct, v, i++);
}
Value res = rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, llvmStruct,
pendings);
SmallVector<Value> results;
for (int i = 0; i < acc.size(); ++i) {
results.push_back(extract_val(types[0], res, i));
}
return results;
}
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
Operation *op, Value a, Value b, Value c, Value d,
@@ -427,7 +449,7 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
if (sync)
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, 0);
mmaResults = emitWait(rewriter, loc, mmaResults, 0);
SmallVector<Value> results =
unpackAccumulator(rewriter, loc, mmaResults, dTensorTy);

View File

@@ -6,7 +6,24 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
} else {
ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
}
return ret;
}
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
@@ -356,11 +373,115 @@ const std::string Bf16_to_Fp8E5M2 =
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
#endif
=======
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
} else {
ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
}
return ret;
}
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
// exponent compensate = 112
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
} else {
ret = "{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
".reg .b16 b<2>; \n"
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
"mov.b32 {a0, a1}, a; \n"
"cvt.bf16.f16 b0, a0; \n"
"cvt.bf16.f16 b1, a1; \n"
"mov.b32 $0, {b0, b1}; \n"
"}";
}
return ret;
}
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret =
"{ \n" // bf16=fp8>>3 + 112<<7
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
"mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800
"mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0
"mov.u32 rn_, 0x00100010; \n" // round to nearest
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
// nosign = clamp(nosign, min, max)
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
"sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800
"sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset)
"shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3
"shl.b32 nosign1, nosign1, 3; \n" // shift into to fp8e4
"prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200
// nosign1 = 0xf300f400
// nosign = 0xf3f4f1f2
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
} else {
ret = "{ \n"
".reg .b16 a<2>; \n"
".reg .f32 b<2>; \n"
"mov.b32 {a0, a1}, $1; \n"
"cvt.f32.bf16 b0, a0; \n"
"cvt.f32.bf16 b1, a1; \n"
"cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n"
"}";
}
return ret;
}
/* ----- FP8E4M3B15 ------ */
// This data-type is a variant of the standard FP8E4M3 format.
// It was designed for fast software conversion to FP16 on
// nvidia GPUs that do not support it natively.
<<<<<<< HEAD
// Specifically, this data-type:
// - has infinities
// - has multiple nans (when all exponent bits are 1)
@@ -404,6 +525,11 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
=======
// This is the same format as FP8E4M3Nv, but:
// - the exponent bias is 15 instead of 7
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
const std::string Fp8E4M3B15_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
@@ -416,6 +542,7 @@ const std::string Fp8E4M3B15_to_Fp16 =
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"shl.b32 $1, b1, 7; \n"
"} \n";
<<<<<<< HEAD
#endif
#ifdef USE_ROCM
@@ -464,6 +591,10 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
=======
static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
std::string ret;
ret += "{ \n"
".reg .pred p<4>; \n"
@@ -509,6 +640,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
// $0 = (($2 << 1) & 0x80008000u) | (($2 << 7) & 0x3f803f80u);
// $1 = (($2 << 0) & 0x80008000u) | (($2 << 0) & 0x3f803f80u);
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -540,6 +672,9 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp8E4M3B15x4_to_Fp16 =
=======
static const std::string Fp8E4M3B15x4_to_Fp16 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"add.u32 a0, $2, $2; \n"
@@ -557,6 +692,7 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
// ((e4.y >> 0) & (0x80008000u >> 0)) |
// ((e4.y >> 0) & (0x3f803f80u >> 0)) ;
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
@@ -591,6 +727,9 @@ Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
}
#else
const std::string Fp16_to_Fp8E4M3B15x4 =
=======
static const std::string Fp16_to_Fp8E4M3B15x4 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"shr.b32 a0, $1, 1; \n"
@@ -904,17 +1043,18 @@ const std::string Bf16_to_Fp8E4M3 =
#endif
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
"}";
static const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
"}";
// Fp16 (x2) -> Fp8E4M3 (x2) (packed)
const std::string Fp16_to_Fp8E4M3Nv = "{ \n"
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";
static const std::string Fp16_to_Fp8E4M3Nv =
"{ \n"
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";
#ifndef USE_ROCM
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
const std::string Fp8E4M3Nv_to_Bf16 =
static const std::string Fp8E4M3Nv_to_Bf16 =
"{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
@@ -927,7 +1067,7 @@ const std::string Fp8E4M3Nv_to_Bf16 =
"}";
// Bf16 (x2) -> Fp8E4M3 (x2) (packed)
const std::string Bf16_to_Fp8E4M3Nv =
static const std::string Bf16_to_Fp8E4M3Nv =
"{ \n"
".reg .b16 a<2>; \n"
".reg .f32 b<2>; \n"
@@ -938,7 +1078,7 @@ const std::string Bf16_to_Fp8E4M3Nv =
"}";
/* ----- Packed integer to BF16 ------ */
const std::string S8_to_Bf16 =
static const std::string S8_to_Bf16 =
"{ \n"
".reg .s8 s<4>; \n"
".reg .f32 f<4>; \n"
@@ -952,6 +1092,12 @@ const std::string S8_to_Bf16 =
"}";
#endif
// Fp32 (x2) -> Fp8 (x2) (packed)
static const std::string Fp32_to_Fp8E4M3Nv =
"cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n";
static const std::string Fp32_to_Fp8E5M2 =
"cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n";
static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
Type inType, Type ouType) {
auto inTensorTy = inType.dyn_cast<RankedTensorType>();
@@ -1383,9 +1529,14 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
<<<<<<< HEAD
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
=======
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
// F16 -> F8
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
@@ -1393,27 +1544,44 @@ struct FpToFpOpConversion
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
<<<<<<< HEAD
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#ifndef USE_ROCM
=======
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
#endif
// BF16 -> F8
<<<<<<< HEAD
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
#ifndef USE_ROCM
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
#endif
=======
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
// F32 -> F8
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
};
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ()) {
if (srcTy.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 && srcTy.isFloat8E5M2())) {
inVecWidthBits = 16;
outVecWidthBits = 32;
}
if (dstTy.isFloat8E4M3FNUZ()) {
if (dstTy.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 && dstTy.isFloat8E5M2())) {
inVecWidthBits = 32;
outVecWidthBits = 16;
}
@@ -1450,18 +1618,24 @@ struct FpToFpOpConversion
size_t numElements = 4;
if (srcElementType.isFloat8E4M3FNUZ() ||
dstElementType.isFloat8E4M3FNUZ()) {
dstElementType.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 &&
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
numElements = 2;
}
bool isSrcFP32 = srcElementType.isF32();
bool useFP16IntermediateSrc =
srcElementType.isF32() &&
!(computeCapability >= 90 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
bool isDstFP32 = dstElementType.isF32();
auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
auto cvtFunc =
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
SmallVector<Value> inVals;
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
}
if (isSrcFP32)
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16(loc, rewriter, v);
inVals.resize(numElements,
@@ -2115,18 +2289,18 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \

View File

@@ -1549,10 +1549,12 @@ struct InsertSliceAsyncOpConversion
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
"Unexpected srcLayout in InsertSliceAsyncOpConversion"));
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
"insert_slice_async: Unexpected rank of %src");
Value llDst = adaptor.getDst();
@@ -1617,25 +1619,15 @@ struct InsertSliceAsyncOpConversion
unsigned numElems = getTotalElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
DenseMap<unsigned, Value> sharedPtrs =
getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
smemObj, rewriter, offsetVals, srcStrides);
// If perPhase * maxPhase > threadsPerCTA, we will have elements
// that share the same tile indices. The index calculation will
// be cached.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// 16 * 8 = 128bits
auto maxBitWidth =

View File

@@ -419,16 +419,15 @@ private:
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];
if (!helper.isReductionOnLayoutFastAxis()) {
std::reverse(order.begin(), order.end());
}
auto smemOrder = helper.getOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> acc = it.second;
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = warpIdAxis;
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShape, smemOrder);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemPtrTy = getElementPtrType(op, i);
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
@@ -513,10 +512,7 @@ private:
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto order = getOrder(srcLayout);
if (!helper.isReductionOnLayoutFastAxis()) {
std::reverse(order.begin(), order.end());
}
auto smemOrder = helper.getOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
@@ -532,7 +528,7 @@ private:
SmallVector<Value> readIdx = resultIndices[j];
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShape, order);
linearize(rewriter, loc, readIdx, smemShape, smemOrder);
Value readPtr =
gep(getElementPtrType(op, i), smemBases[i], readOffset);
resultVals[j] = load(readPtr);

View File

@@ -622,10 +622,13 @@ struct AllocTensorOpConversion
// TODO: we need to modify the pipeline pass to give a proper shared
// encoding to 3D tensors
SmallVector<unsigned> newOrder;
if (resultTy.getShape().size() == 3)
newOrder = {1 + order[0], 1 + order[1], 0};
else
if (resultTy.getShape().size() != order.size()) {
for (auto i = 0; i < order.size(); ++i)
newOrder.push_back(order[i] + 1);
newOrder.push_back(0);
} else {
newOrder = SmallVector<unsigned>(order.begin(), order.end());
}
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
auto smemObj =
@@ -659,10 +662,13 @@ struct ExtractSliceOpConversion
SmallVector<Value, 4> opOffsetVals;
SmallVector<Value, 4> offsetVals;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i))
opOffsetVals.emplace_back(adaptor.getOffsets()[i]);
else
for (auto i = 0, j = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i)) {
// adaptor.getOffsets() returns list of variable offsets. the size of
// the list may not be the same as mixedOffsets
opOffsetVals.emplace_back(adaptor.getOffsets()[j]);
++j;
} else
opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
}

View File

@@ -146,7 +146,8 @@ protected:
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, LLVM::CConv::C, attributes);
/*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/SymbolRefAttr{},
attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
@@ -361,8 +362,13 @@ public:
unsigned numElemsPerSwizzlingRow =
swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth();
Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
unsigned leadingDimOffset;
if (outOrder.size() == 2) {
leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
} else {
leadingDimOffset = numElemsPerSwizzlingRow;
}
Value leadingDimOffsetVal = i32_val(leadingDimOffset);
// Return values
DenseMap<unsigned, Value> ret;
@@ -374,9 +380,15 @@ public:
// Extract multi dimensional index for current element
auto idx = srcIndices[elemIdx];
Value idxCol = idx[outOrder[0]]; // contiguous dimension
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
Value idxRow, strideRow;
if (outOrder.size() == 2) {
idxRow = idx[outOrder[1]]; // discontiguous dimension
strideRow = srcStrides[outOrder[1]];
} else {
idxRow = i32_val(0);
strideRow = i32_val(0);
}
Value strideCol = srcStrides[outOrder[0]];
Value strideRow = srcStrides[outOrder[1]];
// compute phase = (row // perPhase) % maxPhase
Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
// extract dynamic/static offset for immediate offsetting
@@ -428,10 +440,16 @@ public:
offset = add(offset, add(rowOff, mul(colOff, strideCol)));
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
// compute immediate offset
Value immedateOff =
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff);
Value immediateOff;
if (outOrder.size() == 2) {
immediateOff =
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
} else {
immediateOff = i32_val(immedateOffCol);
}
ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff);
}
return ret;
}

View File

@@ -371,13 +371,15 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
Type type = val.getType();
if (type != i32_ty) {
val = bitcast(val, int_ty(bits));
val = zext(i32_ty, val);
if (bits < 32)
val = zext(i32_ty, val);
}
Value mask = i32_val(0xFFFFFFFF);
Value result = rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, val, i, clamp,
mode, UnitAttr());
if (type != i32_ty) {
result = trunc(int_ty(bits), result);
if (bits < 32)
result = trunc(int_ty(bits), result);
result = bitcast(result, type);
}
return result;

View File

@@ -97,7 +97,7 @@
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
} while (0)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
// Types

View File

@@ -115,8 +115,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
// Floating point
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
// MaxMin
GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
// Floating point
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,

View File

@@ -1,9 +1,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

View File

@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
TritonGPUAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRGPUOps
MLIRGPUDialect
TritonIR
)

View File

@@ -1603,7 +1603,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
// Deduce operandSegmentSizes from the number of the operands.
auto operandSegmentSizesAttrName =
OpT::getOperandSegmentSizesAttrName(result.name);
result.addAttribute(
@@ -1616,7 +1616,7 @@ template <class OpT>
void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
printer << " ";
printer << insertSliceOp.getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
// "operandSegmentSizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(
insertSliceOp->getAttrs(),
{insertSliceOp.getOperandSegmentSizesAttrName()});

View File

@@ -139,7 +139,10 @@ class BlockedToMMA : public mlir::RewritePattern {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
mlir::getBackwardSlice(x, &slice, bwdFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = bwdFilter;
getBackwardSlice(x, &slice, opt);
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
@@ -235,8 +238,11 @@ public:
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
getBackwardSlice(a, &aBwdSlices, {isCvt});
getBackwardSlice(b, &bBwdSlices, {isCvt});
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = isCvt;
getBackwardSlice(a, &aBwdSlices, opt);
getBackwardSlice(b, &bBwdSlices, opt);
// get the source of the first conversion found in slices
auto getCvtArgOrder = [](Operation *op) {
return cast<ConvertLayoutOp>(op)

View File

@@ -98,7 +98,9 @@ public:
// and all operations between the load and the conversion
// should be layout preserving
SetVector<Operation *> slice;
getBackwardSlice(op, &slice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op, &slice, opt);
int loadIdx = -1;
bool checkOp = false;
for (int i = 0; i < slice.size(); i++) {

View File

@@ -160,6 +160,8 @@ class LoopPipeliner {
void checkOpShareBarriers(SetVector<Operation *> &ops);
int numLoadsRequireAsyncWait = 0;
int numLoadsRequireMBarrier = 0;
// Number of buffers to allocate for each input.
int numSharedMemorySlices = 0;
/// Iterator values
Value nextIV;
@@ -280,9 +282,12 @@ class LoopPipeliner {
public:
LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs,
bool mode, ConsumerReleaseMap &consumerReleaseMap)
bool mode, int numSharedMemorySlices,
ConsumerReleaseMap &consumerReleaseMap)
: forOp(forOp), numStages(numStages), numWarps(numWarps),
numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) {
numCTAs(numCTAs), mode(mode),
numSharedMemorySlices(numSharedMemorySlices),
consumerReleaseMap(consumerReleaseMap) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
@@ -644,7 +649,7 @@ void LoopPipeliner::createBufferTypes() {
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
bufferShape.insert(bufferShape.begin(), numSharedMemorySlices);
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
Attribute sharedEnc;
if (auto dotOpEnc = cvt.getType()
@@ -946,6 +951,11 @@ void LoopPipeliner::emitPrologue() {
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
Value numSlices = builder.create<arith::ConstantIntOp>(
iv.getLoc(), numSharedMemorySlices, 32);
Value _0 = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx,
numSlices, pipelineIterIdx, _0);
// Some values have not been used by any ops in the loop body
for (BlockArgument arg : forOp.getRegionIterArgs())
setValueMappingYield(arg, valueMapping[arg][stage], stage + 1);
@@ -1220,11 +1230,13 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
Value _1 = builder.create<arith::ConstantIntOp>(idxLoc, 1, 32);
Value numStagesVal =
builder.create<arith::ConstantIntOp>(idxLoc, numStages, 32);
Value numSlices =
builder.create<arith::ConstantIntOp>(idxLoc, numSharedMemorySlices, 32);
// nextWaitIdx
Value waitIdxPlusOne = builder.create<arith::AddIOp>(idxLoc, curWaitIdx, _1);
Value nextWaitIdx = getBoundedIterationValue(
builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0);
Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne,
numSlices, waitIdxPlusOne, _0);
// Indices of InsertSliceAsyncOp and ExtractSliceOp
Value insertSliceIndex = pipelineIterIdx;
@@ -1417,9 +1429,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
// Bump pipelineIterIdx
Value pipelineIterIdxPlusOne =
builder.create<arith::AddIOp>(idxLoc, pipelineIterIdx, _1);
pipelineIterIdx =
getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal,
pipelineIterIdxPlusOne, _0);
pipelineIterIdx = getBoundedIterationValue(
builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0);
// Bump curWaitIdx
curWaitIdx = nextWaitIdx;
@@ -1516,10 +1527,23 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
// applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
llvm::SmallVector<scf::ForOp> newForOps;
// Currently we schedule stage 0 after stage `numStages - 1` during
// pipelining therefore we only need `numStages - 1` slice of memory.
// On Hopper we have a separate post-processing that pipelines wgmma so we
// need an extra buffer for each input.
// Note that an alternative would be to keep allocating `numStages` buffers
// and remove the barrier between the loads from shared memory and the
// copies from global to shared. This would require improving existing
// membar analysis.
int numSharedMemorySlices =
computeCapability < 90 ? numStages - 1 : numStages;
// Do the pipelining
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
this->numCTAs, mode, consumerReleaseMap);
this->numCTAs, mode, numSharedMemorySlices,
consumerReleaseMap);
if (pipeliner.initialize().failed())
return;
@@ -1593,7 +1617,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
/// dots to be pipelined
SetVector<Value> dots;
SmallVector<tt::DotOp> dots;
SmallVector<unsigned> resultNeedSync;
for (Operation &op : *loop) {
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
@@ -1615,8 +1640,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
if (!CArg || !CArg.hasOneUse())
valid = false;
if (valid)
dots.insert(dotOp);
if (valid) {
dots.push_back(dotOp);
resultNeedSync.push_back(
dotOp->getUses().begin()->getOperandNumber());
}
}
}
}
@@ -1627,39 +1655,39 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
return;
OpBuilder builder(forOp);
// 0. insert dot_wait after the last dot in the loop
Value dot = dots.back();
auto loc = dot.getLoc();
builder.setInsertionPointAfter(dot.getDefiningOp());
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(loc, dots.size());
// 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
// wgmma ops by one stage.
// This is needed to prevent shared memory inputs to be overriden before the
// operation is completed.
// TODO: merge this with the rest of the pipelining transformation and look at
// a better representation for async dots.
tt::DotOp lastDot = dots.back();
builder.setInsertionPointAfter(lastDot);
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
lastDot.getLoc(), lastDot.getResult(), dots.size());
// 1. replace Dot with DotAsync
for (size_t idx = 0; idx < dots.size(); ++idx) {
Value dot = dots[idx];
auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
builder.setInsertionPoint(dot.getDefiningOp());
tt::DotOp dotOp = dots[idx];
builder.setInsertionPoint(dotOp);
auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
dot.getDefiningOp()->erase();
dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
dotOp.replaceAllUsesWith(dotAsync.getResult());
updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
dotOp->erase();
}
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
Value loopNotEmpty = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
forOp.getUpperBound());
// TODO[goostavz]: it's a workaround to put the DotWaitOp in an IfOp for
// a bug in ptxas which mistakenly analysis the control flow and turn the GMMA
// into synchronuous implementation for safety.
// Remove this If once the bug is fixed.
auto ifOp = builder.create<scf::IfOp>(loc, ArrayRef<Type>{}, loopNotEmpty,
/*hasElse*/ false);
builder.setInsertionPointToStart(ifOp.thenBlock());
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), 0);
for (unsigned resultIndex : resultNeedSync) {
Value result = forOp->getResult(resultIndex);
if (result.use_empty())
continue;
auto dotWait =
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
}
}
Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,

View File

@@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr;
//
// -----------------------------------------------------------------------------
<<<<<<< HEAD
// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
@@ -102,6 +103,9 @@ public:
};
//
=======
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
class ConvertDotConvert : public mlir::RewritePattern {
public:
ConvertDotConvert(mlir::MLIRContext *context)
@@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
getForwardSlice(currentValue, &forwardSlice);
for (Operation *op : forwardSlice) {
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
if (convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>())
return true;
Attribute dstEncoding = convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding();
if (auto mmaLayout =
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
return (mmaLayout.getVersionMajor() > 1) ? true
: mmaLayout == encoding;
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return encoding.cast<triton::gpu::MmaEncodingAttr>()
.getVersionMajor() > 1;
}
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)
@@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
return rewrittenValue;
OpBuilder rewriter(value.getContext());
rewriter.setInsertionPointAfterValue(rewrittenValue);
// Workaround: The pipeliner will insert async.wait after a pipelined loop
// to ensure that there is no pending copies and it is safe to re-use shared
// memory. We shouldn't insert ops that may use shared memory in between the
// loop and the async.wait. This is a hack until we fix the IR
// representation of async wait.
if (Operation *op = rewrittenValue.getDefiningOp()) {
if (isa<triton::gpu::AsyncWaitOp>(op->getNextNode()))
rewriter.setInsertionPointAfter(op->getNextNode());
}
auto tmpType = RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
Value converted = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -1122,7 +1140,6 @@ public:
hoistConvert(m);
mlir::RewritePatternSet decomposePatterns(context);
decomposePatterns.add<DecomposeDotOperand>(context);
decomposePatterns.add<ConvertDotConvert>(context);
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
.failed()) {

View File

@@ -91,7 +91,7 @@ private:
// suport ForOp only
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
// prologue
auto iterOperands = forOp.getIterOperands();
auto iterOperands = forOp.getInitArgs();
if (argNum == 0)
return false;
if (dependOnSharedEncOperand(iterOperands[argNum - 1]))

View File

@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
arith::UIToFPOp, arith::XOrIOp>(op))
return true;
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,

View File

@@ -220,7 +220,9 @@ public:
SetVector<Operation *> backwardSlice;
mod.walk([&](triton::MakeTensorPtrOp op) -> void {
assert(isa<triton::FuncOp>(op->getParentOp()));
getBackwardSlice(op.getOperation(), &backwardSlice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op.getOperation(), &backwardSlice, opt);
op->removeAttr("async_agent");
});
for (auto op : backwardSlice) {

View File

@@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) {
builder.setInsertionPoint(agentIdOp);
Value globalRoleId = builder.create<arith::ConstantIntOp>(loc, 0, 32);
int globalNumWarps = 0;
SmallVector<Operation *> deprecatedOps;
for (auto cmpOp : agentIdOp->getUsers()) {
assert(isa<arith::CmpIOp>(cmpOp));
for (auto u : cmpOp->getUsers()) {
@@ -111,11 +112,14 @@ void materializeGetAgentIdOp(Operation *parentOp) {
Value cond =
builder.create<arith::AndIOp>(loc, lowerBound, upperBound);
cmpOp->getResult(0).replaceAllUsesWith(cond);
cmpOp->erase();
deprecatedOps.push_back(cmpOp);
break;
}
}
}
for (Operation *cmpOp : deprecatedOps) {
cmpOp->erase();
}
});
}
@@ -145,39 +149,24 @@ LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) {
}
Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op,
bool skipFirstWait) {
bool emptyBarrier) {
// TODO: currently we only support one loop, no nested loop, while or
// condition.
auto loc = op->getLoc();
auto forOp = op->getParentOfType<scf::ForOp>();
if (!forOp) {
return builder.create<arith::ConstantIntOp>(loc, skipFirstWait, 1);
return builder.create<arith::ConstantIntOp>(loc, emptyBarrier, 1);
}
auto defOp = op->getOperand(0).getDefiningOp();
assert(isa<ttng::CreateTokenOp>(defOp) &&
"mbarrier's definingOp is not createTokenOp");
ttng::CreateTokenOp createTokenOp = dyn_cast<ttng::CreateTokenOp>(defOp);
Value numStage =
builder.create<arith::ConstantIntOp>(loc, createTokenOp.getNum(), 32);
Value curStep = forOp.getBody()->getArguments().back();
if (curStep.getType() == builder.getIndexType()) {
curStep =
builder.create<arith::IndexCastOp>(loc, numStage.getType(), curStep);
// for (..., phase, pipelineIdx)
unsigned numArgs = forOp.getBody()->getNumArguments();
assert(numArgs > 2 && "Unexpected number of arguments");
Value curPhase = forOp.getBody()->getArgument(numArgs - 2);
if (emptyBarrier) {
Value _1_1b = builder.create<arith::ConstantIntOp>(loc, 1, 1);
curPhase = builder.create<mlir::arith::XOrIOp>(loc, curPhase, _1_1b);
}
Value curPhase = builder.create<arith::DivUIOp>(loc, curStep, numStage);
if (skipFirstWait) {
// If skipFirstWait, it waits for phaseBit 1
Value _1 = builder.create<arith::ConstantIntOp>(loc, 1, 32);
curPhase = builder.create<arith::AddIOp>(loc, curPhase, _1);
}
Value _2 = builder.create<arith::ConstantIntOp>(loc, 2, 32);
// TODO: May use alternative methods of phaseBit calculation to avoid high
// overhead of RemOp
Value phaseBit = builder.create<arith::RemUIOp>(loc, curPhase, _2);
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, phaseBit,
_0);
return curPhase;
}
int getTxBytes(ttng::InsertSliceAsyncV2Op load) {
@@ -260,7 +249,7 @@ void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op,
auto loc = op.getLoc();
// The first producer_aquire should be met immediately, so initailly producer
// skips the fisrt wait
Value phase = getMBarrierPhaseBit(builder, op, 1);
Value phase = getMBarrierPhaseBit(builder, op, true);
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferEmpty, phase);
assert(op.getOperation()->hasAttr("async_agent"));
setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -296,7 +285,7 @@ void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op,
void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
Value bufferFull) {
auto loc = op.getLoc();
Value phase = getMBarrierPhaseBit(builder, op, 0);
Value phase = getMBarrierPhaseBit(builder, op, false);
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferFull, phase);
assert(op.getOperation()->hasAttr("async_agent"));
setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -530,6 +519,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
builder.create<arith::ConstantIntOp>(loc, nameBarrierId - 1, 32);
// Process mutex users
int numUsers = 0;
SmallVector<Operation *> deprecatedOps;
for (Operation *user : createMutexOp.getResult().getUsers()) {
numUsers++;
assert(numUsers <= 2);
@@ -543,14 +533,20 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
Value barLeave = builder.create<arith::SelectOp>(
loc, isRole0, namedBarrierId1, namedBarrierId0);
builder.create<ttng::NamedBarrierArriveOp>(loc, barLeave, numThreads);
} else
} else {
llvm_unreachable("Unexpected user of mutex");
}
deprecatedOps.push_back(user);
}
for (Operation *user : deprecatedOps) {
user->erase();
}
nameBarrierId -= 2;
nameBarrierIdEnd -= 2;
createMutexOp.erase();
});
parentOp->walk(
[](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); });
}
void processLockOp(OpBuilder &builder, ttng::LockOp op) {
@@ -587,6 +583,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
OpBuilder builder(createMutexOp);
// Process mutex users
SmallVector<Operation *> deprecatedOps;
for (Operation *user : createMutexOp.getResult().getUsers()) {
auto loc = user->getLoc();
builder.setInsertionPoint(user);
@@ -596,6 +593,9 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
processUnlockOp(builder, op);
else
llvm_unreachable("Unexpected user of mutex");
deprecatedOps.push_back(user);
}
for (Operation *user : deprecatedOps) {
user->erase();
}

View File

@@ -156,14 +156,20 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
persistentForOp.getInitArgsMutable()
.slice(persistentForOp.getInitArgs().size() - 1, 1)
.assign(newIdx);
auto yield =
llvm::cast<scf::YieldOp>(persistentForOp.getBody()->getTerminator());
auto idxPlusOneOp =
yield->getOperand(yield->getNumOperands() - 1).getDefiningOp();
assert(isa<arith::AddIOp>(idxPlusOneOp));
assert(idxPlusOneOp->getOperand(0) ==
persistentForOp.getBody()->getArgument(
persistentForOp.getBody()->getNumArguments() - 1));
pipelineIdx = persistentForOp.getBody()->getArgument(
persistentForOp.getBody()->getNumArguments() - 1);
Operation *idxPlusOneOp = nullptr;
for (OpOperand &v : pipelineIdx.getUses()) {
if (isa<arith::AddIOp>(v.getOwner())) {
idxPlusOneOp = v.getOwner();
break;
}
}
assert(idxPlusOneOp && "idxPlusOneOp should be arith::AddIOp");
Operation *use = *idxPlusOneOp->getUsers().begin();
assert(isa<scf::YieldOp>(use) || isa<arith::SelectOp>(use) ||
isa<arith::CmpIOp>(use));
idxPlusOneOp->setOperand(1, numRolesValue);
// Add operations at the start of persistentForOp
@@ -213,45 +219,6 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
unlockLocs[i] = op;
}
// Update unlockLocs
// ====================== IR after async launch dots ======================
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
// %3) {
// * triton_nvidia_gpu.producer_wait arg2
// * %5 = triton_nvidia_gpu.dot_async %4, %5
// * triton_nvidia_gpu.dot_wait {pendings = 1}
// * %6 = arith.cmpi sgt, arg0, %c0
// * scf.if %6 {
// * %7 = arith.subi arg2, c1
// * triton_nvidia_gpu.consumer_release %7
// * }
// * %8 = arith.addi arg2, c1
// * scf.yield %5, %8
// * }
// * triton_nvidia_gpu.dot_wait {pendings = 0}
// * %9 = arith.subi %0#1, c1
// * triton_nvidia_gpu.consumer_release %9
// * =======================================================================
// after async launch dots, there will be outstanding consumerReleaseOp after
// ForOp. we should expend the unlockLocs from ForOp to the outstanding
// consumerReleaseOp.
for (int i = 0; i < numRoles; ++i) {
Operation *unlockOp = unlockLocs[i];
auto filter = [&](Operation *op) {
return op->getBlock() == unlockOp->getBlock();
};
if (isa<scf::ForOp>(unlockOp)) {
SetVector<Operation *> slices;
mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter});
auto iter = llvm::find_if(slices, [](Operation *op) {
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
});
if (iter != slices.end()) {
unlockLocs[i] = *iter;
}
}
}
// Only cases where all lock/unlock locations are in same level make sense.
for (int i = 1; i < numRoles; ++i) {
if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() ||
@@ -281,6 +248,54 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
else
lockLocs[i] = unlockLocs[prevTypeIds[i]];
}
// Update lockLocs
// ====================== IR after async launch dots ======================
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
// %3) {
// * triton_nvidia_gpu.producer_wait arg2
// * %5 = triton_nvidia_gpu.dot_async %4, %5
// * triton_nvidia_gpu.dot_wait {pendings = 1}
// * %6 = arith.cmpi sgt, arg0, %c0
// * scf.if %6 {
// * %7 = arith.subi arg2, c1
// * triton_nvidia_gpu.consumer_release %7
// * }
// * %8 = arith.addi arg2, c1
// * scf.yield %5, %8
// * }
// * triton_nvidia_gpu.dot_wait {pendings = 0}
// * ...
// * triton_nvidia_gpu.consumer_release ..
// * =======================================================================
// after async launch dots, there will be outstanding consumerReleaseOp after
// ForOp. we should set the epilogue lockLocs after the outstanding
// consumerReleaseOp.
for (int i = 0; i < numRoles; ++i) {
Operation *lockOp = lockLocs[i];
if (isa<scf::ForOp>(lockOp)) {
Operation *loc = nullptr;
unsigned numOutstandingConsumerRelease = 0;
for (auto v : lockOp->getResults()) {
SetVector<Operation *> slices;
mlir::getForwardSlice(v, &slices);
auto iter = llvm::find_if(slices, [](Operation *op) {
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
});
if (iter != slices.end()) {
numOutstandingConsumerRelease++;
loc = *iter;
}
}
assert(numOutstandingConsumerRelease <= 1 &&
"should have only one outstanding "
"consumerReleaseOp after "
"async launch dots");
if (loc)
lockLocs[i] = loc;
}
}
// lock
for (int i = 0; i < numRoles; ++i) {
builder.setInsertionPointAfter(lockLocs[i]);

View File

@@ -129,11 +129,12 @@ DenseMap<Channel *, Value> createBuffer(const SmallVector<Channel *> &channels,
}
//===----------------------------------------------------------------------===//
// appendPipelineIdxToLoopArgs
// createNewLoops
//===----------------------------------------------------------------------===//
scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
// for(...) -> for(..., pipelineIdx)
scf::ForOp createNewPersistentLoop(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
auto loc = forOp.getLoc();
Block *body = forOp.getBody();
@@ -200,6 +201,117 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
return newForOp;
}
// for(...) -> for(..., phase, pipelineIdx)
scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
auto loc = forOp.getLoc();
Block *body = forOp.getBody();
// The agentId set of pipelineIdx is the union of agentId sets of all ops in
// the for loop
OpBuilderWithAgentIds builder(forOp.getContext());
builder.setAgentIdsFromArray(collectAgentIds(forOp));
builder.setInsertionPoint(forOp);
Value numStagesVal =
builder.createWithAgentIds<arith::ConstantIntOp>(loc, numStages, 32);
// 0. Append pipelineIdx to block arguments
Value phase =
body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc);
Value pipelineIdx =
body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc);
// 1. prepare index and phase for next iteration
// nextIdx = curIdx + 1
// nextPhase = ((nextIdx < numStages && curPhase) || (nextIdx >= numStages &&
// curPhase^1))
// nextIdx = nextIdx >= numStages ? 0 : nextIdx
auto yieldOp = llvm::cast<scf::YieldOp>(body->getTerminator());
builder.setInsertionPoint(yieldOp);
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
Value _1_1b = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 1);
// generate index for next iter
Value nextPipelineIdx =
builder.createWithAgentIds<arith::AddIOp>(loc, pipelineIdx, one);
Value pipelineGECond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::uge, nextPipelineIdx, numStagesVal);
Value pipelineLTCond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, nextPipelineIdx, numStagesVal);
Value cyclePipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, nextPipelineIdx, numStagesVal);
nextPipelineIdx = builder.createWithAgentIds<mlir::arith::SelectOp>(
loc, pipelineGECond, cyclePipelineIdx, nextPipelineIdx);
// generate phase for next iter
Value flipPhase =
builder.createWithAgentIds<mlir::arith::XOrIOp>(loc, phase, _1_1b);
Value cond0 = builder.createWithAgentIds<mlir::arith::AndIOp>(
loc, pipelineGECond, flipPhase);
Value cond1 = builder.createWithAgentIds<mlir::arith::AndIOp>(
loc, pipelineLTCond, phase);
Value nextPhase =
builder.createWithAgentIds<mlir::arith::OrIOp>(loc, cond0, cond1);
// 2. Append pipelineIdx to yield operands
yieldOp->insertOperands(yieldOp.getNumOperands(),
{nextPhase, nextPipelineIdx});
// 3. create newLoopArgs
SmallVector<Value> newLoopArgs;
for (auto operand : forOp.getInitArgs())
newLoopArgs.push_back(operand);
builder.setInsertionPoint(forOp);
Value initPipelineIdx, initEmptyIdx, initPhase;
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
if (parentForOp) {
// Make sure prior pipelineIdx is inserted in the end of parentForOp
initPipelineIdx = parentForOp.getBody()->getArguments().back();
Value numSteps = builder.createWithAgentIds<arith::SubIOp>(
loc, forOp.getUpperBound(), forOp.getLowerBound());
numSteps = builder.createWithAgentIds<arith::AddIOp>(loc, numSteps,
forOp.getStep());
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value two = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 2, 32);
numSteps = builder.createWithAgentIds<arith::SubIOp>(loc, numSteps, one);
numSteps = builder.createWithAgentIds<arith::DivUIOp>(loc, numSteps,
forOp.getStep());
// initPipelineIdx = (parentForOp.pipelineIdx * numSteps) % numStages
// initPhase = ((parentForOp.pipelineIdx * numSteps) / numStages) & 1
initPipelineIdx = builder.createWithAgentIds<arith::MulIOp>(
loc, initPipelineIdx, numSteps);
Value pipelineIdx = builder.createWithAgentIds<arith::DivUIOp>(
loc, initPipelineIdx, numStagesVal);
initPipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, initPipelineIdx,
builder.createWithAgentIds<arith::MulIOp>(loc, pipelineIdx,
numStagesVal));
pipelineIdx =
builder.createWithAgentIds<arith::AndIOp>(loc, pipelineIdx, one);
initPhase = builder.createWithAgentIds<arith::TruncIOp>(
loc, builder.getI1Type(), pipelineIdx);
} else {
// phase init to false and pipelineIdx init to 0
initPipelineIdx = zero;
initPhase = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 1);
}
newLoopArgs.append({initPhase, initPipelineIdx});
// 4. Create newForOp and take the region of forOp
auto newForOp = builder.createWithAgentIds<scf::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
newLoopArgs);
newForOp.getRegion().takeBody(forOp.getRegion());
// 5. Replace forOp with newForOp
for (unsigned i = 0; i < forOp.getNumResults(); ++i)
forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));
forOp.erase();
return newForOp;
}
//===----------------------------------------------------------------------===//
// appendPipelineIdxArgs
//===----------------------------------------------------------------------===//
@@ -217,7 +329,22 @@ void appendPipelineIdxArgs(SmallVector<Operation *> &backbone, int numStages) {
for (auto &op : orderedForOps) {
scf::ForOp parentForOp = op->getParentOfType<scf::ForOp>();
auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp);
scf::ForOp newForOp;
bool hasDotOp = false;
for (Operation &subOp : *op.getBody()) {
if (isa<triton::DotOp>(&subOp)) {
hasDotOp = true;
break;
}
}
if (hasDotOp) {
// for(...) -> for(..., phase, pipelineIdx)
newForOp = createNewMathLoop(op, numStages, parentForOp);
} else {
// for(...) -> for(..., pipelineIdx)
newForOp = createNewPersistentLoop(op, numStages, parentForOp);
}
auto backboneForItr =
std::find(backbone.begin(), backbone.end(), op.getOperation());
if (backboneForItr != backbone.end()) {
@@ -688,8 +815,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
}
builder.setAgentIdsFromArray(agentsPC);
Value pipelineIdx;
Value numStagesVal = builder.createWithAgentIds<arith::ConstantIntOp>(
headProducer->getLoc(), numStages, 32);
if (auto forOp = headProducer->getParentOfType<scf::ForOp>()) {
pipelineIdx = forOp.getBody()->getArguments().back();
} else {
@@ -700,10 +825,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
// insert ProducerAcquireOp
builder.setInsertionPoint(headProducer);
if (headProducer->getParentOfType<scf::ForOp>()) {
pipelineIdx = builder.createWithAgentIds<arith::RemSIOp>(
headProducer->getLoc(), pipelineIdx, numStagesVal);
}
builder.setAgentIdsFromArray(agentP);
builder.createWithAgentIds<ttng::ProducerAcquireOp>(headProducer->getLoc(),
token, pipelineIdx);
@@ -738,7 +859,8 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
loc, dotAsync.getResult(), 1);
// 1. insert ConsumerReleaseOp for DotAsyncOps
Value cond = builder.createWithAgentIds<arith::CmpIOp>(
@@ -747,31 +869,43 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
auto ifOp =
builder.createWithAgentIds<scf::IfOp>(loc, ArrayRef<Type>{}, cond,
/*hasElse*/ false);
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
builder.setInsertionPointToStart(ifOp.thenBlock());
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(
headConsumer->getLoc(), 1, 32);
auto oriIdx = forOp.getBody()->getArguments().back();
Value consumerReleaseIdx =
builder.createWithAgentIds<arith::SubIOp>(loc, oriIdx, one);
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
loc, consumerReleaseIdx, numStagesVal);
Value consumerReleaseIdx = forOp.getBody()->getArguments().back();
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
loc, numStages - 1, 32);
Value consumerReleaseIdxMinusOne =
builder.createWithAgentIds<arith::SubIOp>(loc, consumerReleaseIdx,
one);
cond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
loc, cond, lastStage, consumerReleaseIdxMinusOne);
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
consumerReleaseIdx);
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
0);
unsigned resultIndex = dotAsync->getUses().begin()->getOperandNumber();
Value result = forOp->getResult(resultIndex);
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
Value one_ = builder.createWithAgentIds<arith::ConstantIntOp>(
headConsumer->getLoc(), 1, 32);
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
loc, numStages - 1, 32);
consumerReleaseIdx = forOp.getResults().back();
consumerReleaseIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, consumerReleaseIdx, one_);
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
loc, consumerReleaseIdx, numStagesVal);
consumerReleaseIdxMinusOne = builder.createWithAgentIds<arith::SubIOp>(
loc, consumerReleaseIdx, one);
cond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
loc, cond, lastStage, consumerReleaseIdxMinusOne);
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
consumerReleaseIdx);
dotOp->erase();

View File

@@ -14,7 +14,6 @@ add_mlir_translation_library(TritonLLVMIR
PUBLIC
MLIRArithToLLVM
MLIRBuiltinToLLVMIRTranslation
MLIRExecutionEngineUtils
MLIRIndexToLLVM
MLIRIR
MLIRLLVMDialect

View File

@@ -44,7 +44,8 @@ static bool findAndReplace(std::string &str, const std::string &begin,
return true;
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
bool enable_fp_fusion) {
// LLVM version in use may not officially support target hardware.
// Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -84,13 +85,15 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto target =
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
if (enable_fp_fusion)
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
opt.TrapUnreachable = true;
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt, llvm::CodeGenOpt::Aggressive)};
std::nullopt, llvm::CodeGenOptLevel::Aggressive)};
// set data layout
if (layout.empty())
module.setDataLayout(machine->createDataLayout());
@@ -106,7 +109,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
llvm::legacy::PassManager pass;
// emit
machine->addPassesToEmitFile(pass, pstream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
llvm::CodeGenFileType::AssemblyFile);
pass.run(module);
}
// post-process