mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge pull request #410 from ROCmSoftwarePlatform/ifu-231117
Ifu 231117
This commit is contained in:
@@ -364,17 +364,23 @@ private:
|
||||
scratchAlignment);
|
||||
}
|
||||
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
|
||||
// only scalar requires scratch memory
|
||||
// make it explicit for readability
|
||||
auto value = op->getOperand(0);
|
||||
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto elemTy =
|
||||
value.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
if (value.getType().dyn_cast<RankedTensorType>()) {
|
||||
// nothing to do
|
||||
} else {
|
||||
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto elemTy =
|
||||
value.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
}
|
||||
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
|
||||
auto callable = callOp.resolveCallable();
|
||||
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
|
||||
|
||||
@@ -635,10 +635,6 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) {
|
||||
return op.getPredicate();
|
||||
}
|
||||
|
||||
static arith::CmpIPredicate getPredicate(arith::CmpIOp op) {
|
||||
return op.getPredicate();
|
||||
}
|
||||
@@ -917,13 +913,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
||||
visitors.append<BroadcastOpAxisInfoVisitor>();
|
||||
visitors.append<SplatOpAxisInfoVisitor>();
|
||||
visitors.append<ExpandDimsOpAxisInfoVisitor>();
|
||||
visitors.append<CmpOpAxisInfoVisitor<arith::CmpIOp>,
|
||||
CmpOpAxisInfoVisitor<triton::gpu::CmpIOp>>();
|
||||
visitors.append<CmpOpAxisInfoVisitor<arith::CmpIOp>>();
|
||||
visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
|
||||
LogicalOpAxisInfoVisitor<arith::OrIOp>,
|
||||
LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
|
||||
visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>,
|
||||
SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
|
||||
visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>>();
|
||||
visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
|
||||
ShROpAxisInfoVisitor<arith::ShRSIOp>>();
|
||||
visitors.append<MaxMinOpAxisInfoVisitor<arith::MaxSIOp>,
|
||||
|
||||
@@ -98,7 +98,7 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent());
|
||||
if (CTAsPerCGA[dim] != 1)
|
||||
assert(0 && "Layout conversion to be implemented");
|
||||
llvm::report_fatal_error("Layout conversion to be implemented");
|
||||
}
|
||||
|
||||
// Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported
|
||||
|
||||
@@ -676,7 +676,8 @@ private:
|
||||
inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape,
|
||||
origRepShape, outOrd, vals, smemBase);
|
||||
} else {
|
||||
assert(0 && "ConvertLayout with input layout not implemented");
|
||||
llvm::report_fatal_error(
|
||||
"ConvertLayout with input layout not implemented");
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -713,7 +714,8 @@ private:
|
||||
paddedRepShape, origRepShape, outOrd, outVals,
|
||||
smemBase);
|
||||
} else {
|
||||
assert(0 && "ConvertLayout with output layout not implemented");
|
||||
llvm::report_fatal_error(
|
||||
"ConvertLayout with output layout not implemented");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
@@ -1159,7 +1161,7 @@ private:
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}; // namespace triton::gpu::ConvertLayoutOp>
|
||||
};
|
||||
|
||||
void populateConvertLayoutOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
|
||||
@@ -422,7 +422,6 @@ MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
needTrans = kOrder != order[0];
|
||||
canUseLdmatrix = elemBytes == 2 || (!needTrans);
|
||||
canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth);
|
||||
// canUseLdmatrix = false;
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
// Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed,
|
||||
|
||||
@@ -146,8 +146,55 @@ struct DotWaitOpConversion
|
||||
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto pendings = op.getPendings();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
|
||||
op, adaptor.getInput(), pendings);
|
||||
Location loc = op.getLoc();
|
||||
if (adaptor.getInputs().size() <= 1) {
|
||||
Value intput =
|
||||
adaptor.getInputs().size() == 1 ? adaptor.getInputs()[0] : Value();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(op, intput,
|
||||
pendings);
|
||||
return success();
|
||||
}
|
||||
std::vector<Type> types;
|
||||
// Pack the inputs into a single struct.
|
||||
for (Value input : adaptor.getInputs()) {
|
||||
auto structType = input.getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
if (!structType)
|
||||
return failure();
|
||||
for (Type type : structType.getBody())
|
||||
types.push_back(type);
|
||||
}
|
||||
auto packedType =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
|
||||
unsigned outputStructIndex = 0;
|
||||
for (Value input : adaptor.getInputs()) {
|
||||
auto structType = input.getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
|
||||
Value value = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, structType.getBody()[i], input, i);
|
||||
packed = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, packedType, packed, value, outputStructIndex++);
|
||||
}
|
||||
}
|
||||
Value packedOutput =
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, packed, pendings);
|
||||
// Unpack the output into the original struct types.
|
||||
SmallVector<Value> outputs;
|
||||
outputStructIndex = 0;
|
||||
for (Value input : adaptor.getInputs()) {
|
||||
auto structType = input.getType().cast<LLVM::LLVMStructType>();
|
||||
Value unpacked = rewriter.create<LLVM::UndefOp>(loc, structType);
|
||||
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
|
||||
Value value = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, packedType.getBody()[outputStructIndex], packedOutput,
|
||||
outputStructIndex);
|
||||
outputStructIndex++;
|
||||
unpacked = rewriter.create<LLVM::InsertValueOp>(loc, structType,
|
||||
unpacked, value, i);
|
||||
}
|
||||
outputs.push_back(unpacked);
|
||||
}
|
||||
rewriter.replaceOp(op, outputs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -332,7 +332,7 @@ Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
ret = "{ \n"
|
||||
ret = "{ \n"
|
||||
".reg .b32 a<2>, b<2>, c<4>, d<4>, e112; \n" // if input = 0xf1f2f3f4
|
||||
"mov.u32 e112, 0x77800000; \n"
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
@@ -357,16 +357,23 @@ static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
} else {
|
||||
ret = "{ \n"
|
||||
".reg .b32 a; \n"
|
||||
".reg .f16 a<2>; \n"
|
||||
".reg .b16 b<2>; \n"
|
||||
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
|
||||
"mov.b32 {a0, a1}, a; \n"
|
||||
"cvt.bf16.f16 b0, a0; \n"
|
||||
"cvt.bf16.f16 b1, a1; \n"
|
||||
"mov.b32 $0, {b0, b1}; \n"
|
||||
"}";
|
||||
ret =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
".reg .b32 e112; \n"
|
||||
"mov.u32 e112, 0x77807780; \n" // 2**112 represented as
|
||||
// bf16x2
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
|
||||
"lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4
|
||||
"mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112
|
||||
"}";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -492,7 +499,7 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
"mov.b32 {a0, a1}, $1; \n"
|
||||
"cvt.f32.bf16 b0, a0; \n"
|
||||
"cvt.f32.bf16 b1, a1; \n"
|
||||
"cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n"
|
||||
"cvt.rn.satfinite.e5m2x2.f32 $0, b1, b0; \n"
|
||||
"}";
|
||||
}
|
||||
return ret;
|
||||
@@ -1058,7 +1065,7 @@ static const std::string Bf16_to_Fp8E4M3Nv =
|
||||
"mov.b32 {a0, a1}, $1; \n"
|
||||
"cvt.f32.bf16 b0, a0; \n"
|
||||
"cvt.f32.bf16 b1, a1; \n"
|
||||
"cvt.rn.satfinite.e4m3x2.f32 $0, b0, b1; \n"
|
||||
"cvt.rn.satfinite.e4m3x2.f32 $0, b1, b0; \n"
|
||||
"}";
|
||||
|
||||
/* ----- Packed integer to BF16 ------ */
|
||||
@@ -1312,8 +1319,118 @@ public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit ElementwiseOpConversionBase(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
axisAnalysisPass(axisAnalysisPass) {}
|
||||
|
||||
// Try to deduplicate the resultVals based on the
|
||||
// constancy properties of the result discovered by
|
||||
// the axis analysis pass. If possible, redundant
|
||||
// computation is eliminated.
|
||||
SmallVector<Value> maybeDeduplicate(SourceOp op,
|
||||
SmallVector<Value> resultVals) const {
|
||||
if (!isMemoryEffectFree(op))
|
||||
// the op has side effects: can't dedup
|
||||
return resultVals;
|
||||
SmallVector<Value> results = op->getResults();
|
||||
if (results.size() == 0 || results.size() > 1)
|
||||
// there must be exactly 1 result
|
||||
return resultVals;
|
||||
Value result = results[0];
|
||||
Type type = result.getType();
|
||||
if (!type)
|
||||
return resultVals;
|
||||
RankedTensorType rtType = type.dyn_cast<RankedTensorType>();
|
||||
if (!rtType)
|
||||
// the result must be a tensor
|
||||
return resultVals;
|
||||
Attribute encoding = rtType.getEncoding();
|
||||
if (!encoding)
|
||||
// encoding not available
|
||||
return resultVals;
|
||||
if (!encoding.dyn_cast<triton::gpu::BlockedEncodingAttr>() &&
|
||||
!encoding.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
|
||||
// TODO: constraining the ecndoing type here is necessary
|
||||
// for avoiding crashes in the triton::gpu::getElemsPerThread
|
||||
// call below happening in the test_core::test_fp8_dot_acc
|
||||
return resultVals;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> elemsPerThread =
|
||||
triton::gpu::getElemsPerThread(rtType);
|
||||
int rank = elemsPerThread.size();
|
||||
if (product<unsigned>(elemsPerThread) != resultVals.size())
|
||||
return resultVals;
|
||||
AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result);
|
||||
if (!axisInfo)
|
||||
// axis info (e.g., constancy) not available
|
||||
return resultVals;
|
||||
SmallVector<unsigned> sizePerThread =
|
||||
triton::gpu::getSizePerThread(encoding);
|
||||
if (rank != sizePerThread.size())
|
||||
return resultVals;
|
||||
|
||||
SmallVector<int64_t> constancy = axisInfo->getConstancy();
|
||||
if (rank != constancy.size())
|
||||
return resultVals;
|
||||
bool hasConstancy = false;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (constancy[i] > sizePerThread[i]) {
|
||||
if (constancy[i] % sizePerThread[i] != 0)
|
||||
// constancy is not evenly covered by sizePerThread
|
||||
return resultVals;
|
||||
// can't move the values across different
|
||||
// "sizePerThread"-sized blocks
|
||||
constancy[i] = sizePerThread[i];
|
||||
}
|
||||
if (elemsPerThread[i] < 1 || constancy[i] < 1)
|
||||
return resultVals;
|
||||
if (!(elemsPerThread[i] % constancy[i] == 0 ||
|
||||
constancy[i] % elemsPerThread[i] == 0))
|
||||
// either the constancy along each dimension must fit
|
||||
// into the elemsPerThread or the other way around
|
||||
return resultVals;
|
||||
if (constancy[i] > 1)
|
||||
hasConstancy = true;
|
||||
}
|
||||
if (!hasConstancy)
|
||||
// nothing to deduplicate
|
||||
return resultVals;
|
||||
|
||||
if (rank > 1) {
|
||||
// reorder the shape and constancy vectors by the axis order:
|
||||
// from the fastest-changing to the smallest-changing axis
|
||||
SmallVector<unsigned> order = triton::gpu::getOrder(encoding);
|
||||
if (rank != order.size())
|
||||
return resultVals;
|
||||
ArrayRef<unsigned> orderRef(order);
|
||||
elemsPerThread = reorder(ArrayRef<unsigned>(elemsPerThread), orderRef);
|
||||
constancy = reorder(ArrayRef<int64_t>(constancy), orderRef);
|
||||
}
|
||||
|
||||
SmallVector<unsigned> strides(rank, 1);
|
||||
for (int i = 1; i < rank; ++i) {
|
||||
strides[i] = strides[i - 1] * elemsPerThread[i - 1];
|
||||
}
|
||||
SmallVector<Value> dedupResultVals;
|
||||
dedupResultVals.reserve(resultVals.size());
|
||||
for (int i = 0; i < resultVals.size(); ++i) {
|
||||
// each coordinate of the orig_idx is "coarsened" using the
|
||||
// constancy along this dimension: the resulting dedup_idx
|
||||
// points to the reused value in the original resultsVal
|
||||
int orig_idx = i;
|
||||
int dedup_idx = 0;
|
||||
for (int j = 0; j < rank; ++j) {
|
||||
int coord_j = orig_idx % elemsPerThread[j];
|
||||
dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
|
||||
orig_idx /= elemsPerThread[j];
|
||||
}
|
||||
dedupResultVals.push_back(resultVals[dedup_idx]);
|
||||
}
|
||||
|
||||
return dedupResultVals;
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||
@@ -1356,6 +1473,7 @@ public:
|
||||
auto argTy = op->getOperand(0).getType();
|
||||
resultVals = reorderValues(resultVals, argTy, resultTy);
|
||||
}
|
||||
resultVals = maybeDeduplicate(op, resultVals);
|
||||
resultVals =
|
||||
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
|
||||
resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy, rewriter, loc);
|
||||
@@ -1367,6 +1485,9 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
protected:
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass;
|
||||
|
||||
private:
|
||||
int computeCapability;
|
||||
};
|
||||
@@ -1398,8 +1519,9 @@ struct FpToFpOpConversion
|
||||
triton::FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase;
|
||||
|
||||
explicit FpToFpOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass,
|
||||
int computeCapability, PatternBenefit benefit = 1)
|
||||
: ElementwiseOpConversionBase(typeConverter, benefit),
|
||||
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
static Value convertBf16ToFp32(Location loc,
|
||||
@@ -1466,14 +1588,14 @@ struct FpToFpOpConversion
|
||||
#endif
|
||||
}
|
||||
|
||||
static Value convertFp32ToFp16(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
static Value convertFp32ToFp16NZ(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
#ifdef USE_ROCM
|
||||
return cvtFp32ToFp16(loc, rewriter, v);
|
||||
#else
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.rn.f16.f32");
|
||||
auto &cvt = *builder.create("cvt.rz.f16.f32");
|
||||
auto res = builder.newOperand("=h");
|
||||
auto operand = builder.newOperand(v, "r");
|
||||
cvt(res, operand);
|
||||
@@ -1557,7 +1679,7 @@ struct FpToFpOpConversion
|
||||
int inVecWidthBits = 32;
|
||||
int outVecWidthBits = 32;
|
||||
if (srcTy.isFloat8E4M3FNUZ() ||
|
||||
(computeCapability >= 90 && srcTy.isFloat8E5M2())) {
|
||||
(computeCapability >= 90 && srcTy.isFloat8E5M2() && dstTy.isF16())) {
|
||||
inVecWidthBits = 16;
|
||||
outVecWidthBits = 32;
|
||||
}
|
||||
@@ -1596,7 +1718,9 @@ struct FpToFpOpConversion
|
||||
dstElementType.isFloat8E5M2FNUZ())
|
||||
#else
|
||||
(computeCapability >= 90 &&
|
||||
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2())))
|
||||
((srcElementType.isFloat8E5M2() &&
|
||||
(dstElementType.isF16() || dstElementType.isF32())) ||
|
||||
dstElementType.isFloat8E5M2()))) {
|
||||
#endif
|
||||
{
|
||||
numElements = 2;
|
||||
@@ -1610,18 +1734,17 @@ struct FpToFpOpConversion
|
||||
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
|
||||
#endif
|
||||
bool isDstFP32 = dstElementType.isF32();
|
||||
auto cvtFunc =
|
||||
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
|
||||
isDstFP32 ? f16_ty : dstElementType);
|
||||
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
|
||||
Type dstType = isDstFP32 ? f16_ty : dstElementType;
|
||||
auto cvtFunc = getConversionFunc(srcType, dstType);
|
||||
SmallVector<Value> inVals;
|
||||
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
|
||||
inVals.push_back(operands[i][0]);
|
||||
}
|
||||
if (useFP16IntermediateSrc)
|
||||
for (Value &v : inVals)
|
||||
v = convertFp32ToFp16(loc, rewriter, v);
|
||||
inVals.resize(numElements,
|
||||
undef(typeConverter->convertType(srcElementType)));
|
||||
v = convertFp32ToFp16NZ(loc, rewriter, v);
|
||||
inVals.resize(numElements, undef(typeConverter->convertType(srcType)));
|
||||
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
|
||||
assert(outVals.size() == inVals.size());
|
||||
outVals.resize(std::min(numElements, operands.size()));
|
||||
@@ -1647,18 +1770,17 @@ Value EmitDualBF16ElementwiseOp(Location loc,
|
||||
}
|
||||
|
||||
struct CmpIOpConversion
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
|
||||
CmpIOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
|
||||
: public ElementwiseOpConversionBase<arith::CmpIOp, CmpIOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<arith::CmpIOp, CmpIOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
SmallVector<LLVM::ICmpOp>
|
||||
createDestOps(triton::gpu::CmpIOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
MultipleOperandsRange operands, Location loc) const {
|
||||
SmallVector<LLVM::ICmpOp> createDestOps(arith::CmpIOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy,
|
||||
MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
return {rewriter.create<LLVM::ICmpOp>(
|
||||
loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()),
|
||||
operands[0][0], operands[0][1])};
|
||||
@@ -1689,16 +1811,14 @@ struct CmpIOpConversion
|
||||
};
|
||||
|
||||
struct CmpFOpConversion
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
|
||||
CmpFOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
|
||||
: public ElementwiseOpConversionBase<arith::CmpFOp, CmpFOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<arith::CmpFOp, CmpFOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
static SmallVector<LLVM::FCmpOp>
|
||||
createDestOps(triton::gpu::CmpFOp op, OpAdaptor adaptor,
|
||||
createDestOps(arith::CmpFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
MultipleOperandsRange operands, Location loc) {
|
||||
return {rewriter.create<LLVM::FCmpOp>(
|
||||
@@ -1890,7 +2010,7 @@ struct FDivOpConversion
|
||||
} else if (64 == bitwidth) {
|
||||
fdiv.o("rn").o("f64");
|
||||
} else {
|
||||
assert(0 && bitwidth && "not supported");
|
||||
llvm::report_fatal_error("Unsupported bitwidth");
|
||||
}
|
||||
|
||||
auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
|
||||
@@ -2249,20 +2369,40 @@ struct IndexCastOpLowering
|
||||
}
|
||||
};
|
||||
|
||||
struct SelectOpConversion
|
||||
: ElementwiseOpConversionBase<mlir::arith::SelectOp, SelectOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<mlir::arith::SelectOp, SelectOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
SmallVector<Value> createDestOps(mlir::arith::SelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
std::array<Value, 3> llvmOperands;
|
||||
if (operands[0].size() == 2) {
|
||||
// Case of scalar condition with tensor operands.
|
||||
assert(op.getCondition().getType().isInteger(1));
|
||||
llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]};
|
||||
} else {
|
||||
llvmOperands = {operands[0][0], operands[0][1], operands[0][2]};
|
||||
}
|
||||
return {rewriter.create<LLVM::SelectOp>(
|
||||
loc, llvmOperands[1].getType(), llvmOperands,
|
||||
adaptor.getAttributes().getValue())};
|
||||
}
|
||||
};
|
||||
|
||||
void populateElementwiseOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
int computeCapability, PatternBenefit benefit) {
|
||||
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
|
||||
POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
|
||||
#undef POPULATE_TERNARY_OP
|
||||
|
||||
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
|
||||
typeConverter, axisInfoAnalysis, benefit);
|
||||
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
|
||||
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
|
||||
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
|
||||
@@ -2286,7 +2426,8 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
#undef POPULATE_BINARY_OP
|
||||
|
||||
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
|
||||
typeConverter, axisInfoAnalysis, benefit);
|
||||
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
|
||||
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
|
||||
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
|
||||
@@ -2302,29 +2443,33 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
|
||||
#undef POPULATE_UNARY_OP
|
||||
|
||||
patterns.add<AbsIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AbsFOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<CmpIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
|
||||
patterns.add<FDivOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FSubOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FAddOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FMulOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<FSubOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<FAddOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<FMulOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
|
||||
patterns.add<ExtFOpConversion>(typeConverter, benefit);
|
||||
patterns.add<TruncFOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FPToSIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<SIToFPOpConversion>(typeConverter, benefit);
|
||||
patterns.add<IndexCastOpLowering>(typeConverter, benefit);
|
||||
patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<IndexCastOpLowering>(typeConverter, axisInfoAnalysis, benefit);
|
||||
|
||||
patterns.add<FpToFpOpConversion>(typeConverter, computeCapability, benefit);
|
||||
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
|
||||
computeCapability, benefit);
|
||||
|
||||
patterns.add<ExternElementwiseOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
|
||||
benefit);
|
||||
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter,
|
||||
axisInfoAnalysis, benefit);
|
||||
// ExpOpConversionApprox will try using ex2.approx if the input type is
|
||||
// FP32. For other input types, ExpOpConversionApprox will return failure and
|
||||
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
|
||||
// __nv_expf for higher-precision calculation
|
||||
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
|
||||
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
|
||||
}
|
||||
|
||||
@@ -28,6 +28,8 @@ static CUtensorMapDataType getCUtensorMapDataType(Type ty) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if (ty.isF32()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if (ty.getIntOrFloatBitWidth() == 8) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
@@ -930,6 +932,11 @@ private:
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (!isa<BlockArgument>(v) &&
|
||||
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
|
||||
v.getDefiningOp()))
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
@@ -974,12 +981,16 @@ private:
|
||||
namespace {
|
||||
void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
|
||||
int numCTAs) {
|
||||
#ifdef USE_ROCM
|
||||
barrier();
|
||||
#else
|
||||
if (numCTAs == 1) {
|
||||
barrier();
|
||||
} else {
|
||||
rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
|
||||
rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@@ -1001,6 +1012,7 @@ struct AtomicCASOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// extract relevant info from Module
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
Value ptr = op.getPtr();
|
||||
@@ -1009,6 +1021,7 @@ struct AtomicCASOpConversion
|
||||
Value llCmp = adaptor.getCmp();
|
||||
Value llVal = adaptor.getVal();
|
||||
|
||||
// prep data by unpacking to get data ready
|
||||
auto ptrElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llPtr, rewriter, op.getPtr().getType());
|
||||
auto cmpElements = getTypeConverter()->unpackLLElements(
|
||||
@@ -1016,53 +1029,106 @@ struct AtomicCASOpConversion
|
||||
auto valElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llVal, rewriter, op.getVal().getType());
|
||||
|
||||
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
// deal with tensor or scalar
|
||||
auto valueTy = op.getResult().getType();
|
||||
auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
auto tid = tid_val();
|
||||
Value pred = icmp_eq(tid, i32_val(0));
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
|
||||
// vec = 1 for scalar
|
||||
auto vec = getVectorSize(op.getPtr());
|
||||
// tensor
|
||||
if (TensorTy) {
|
||||
auto valTy = op.getVal().getType().cast<RankedTensorType>();
|
||||
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
|
||||
}
|
||||
|
||||
Value casPtr = ptrElements[0];
|
||||
Value casCmp = cmpElements[0];
|
||||
Value casVal = valElements[0];
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
|
||||
// Build blocks to bypass the atomic instruction for ~rmwMask.
|
||||
auto *curBlock = rewriter.getInsertionBlock();
|
||||
auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
|
||||
auto *atomicBlock = rewriter.createBlock(
|
||||
curBlock->getParent(), std::next(Region::iterator(curBlock)));
|
||||
// atomic ops
|
||||
for (size_t i = 0; i < elemsPerThread; i += vec) {
|
||||
Value casVal = undef(vecTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
Value iiVal = createIndexAttrConstant(
|
||||
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
||||
casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal);
|
||||
}
|
||||
|
||||
// Fill entry block with global memory barrier and conditional branch.
|
||||
rewriter.setInsertionPointToEnd(curBlock);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
|
||||
Value casPtr = ptrElements[i];
|
||||
Value casCmp = cmpElements[i];
|
||||
casVal = valElements[i];
|
||||
|
||||
// Build main block with atomic_cmpxchg.
|
||||
rewriter.setInsertionPointToEnd(atomicBlock);
|
||||
// use op
|
||||
if (TensorTy) { // for tensor
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
// TODO: USE ATOMIC CAS OP on Tensor
|
||||
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
|
||||
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
|
||||
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
|
||||
loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
|
||||
StringRef("agent"));
|
||||
|
||||
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
|
||||
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
|
||||
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
|
||||
loc, casPtr, casCmp, casVal, successOrdering,
|
||||
failureOrdering, StringRef("agent"));
|
||||
// Extract the new_loaded value from the pair.
|
||||
Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);
|
||||
// Extract the new_loaded value from the pair.
|
||||
Value ret = extract_val(valueElemTy, cmpxchg, i);
|
||||
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
resultVals[i + ii] =
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
|
||||
}
|
||||
} else { // for scalar
|
||||
// Build blocks to bypass the atomic instruction for ~rmwMask.
|
||||
auto *curBlock = rewriter.getInsertionBlock();
|
||||
auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
|
||||
auto *atomicBlock = rewriter.createBlock(
|
||||
curBlock->getParent(), std::next(Region::iterator(curBlock)));
|
||||
|
||||
store(newLoaded, atomPtr);
|
||||
// Fill entry block with global memory barrier and conditional branch.
|
||||
rewriter.setInsertionPointToEnd(curBlock);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
auto tid = tid_val();
|
||||
Value pred = icmp_eq(tid, i32_val(i));
|
||||
rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
|
||||
|
||||
rewriter.create<LLVM::BrOp>(loc, ValueRange(), endBlock);
|
||||
// Build main block with atomic_cmpxchg.
|
||||
rewriter.setInsertionPointToEnd(atomicBlock);
|
||||
|
||||
// Build the last block: synced load from shared memory, exit.
|
||||
rewriter.setInsertionPointToStart(endBlock);
|
||||
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
|
||||
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
|
||||
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
|
||||
loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
|
||||
StringRef("agent"));
|
||||
|
||||
GCNBuilder BuilderMemfenceLDS;
|
||||
BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()();
|
||||
BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
// Extract the new_loaded value from the pair.
|
||||
Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);
|
||||
|
||||
store(newLoaded, atomPtr);
|
||||
|
||||
rewriter.create<LLVM::BrOp>(loc, ValueRange(), endBlock);
|
||||
|
||||
// Build the last block: synced load from shared memory, exit.
|
||||
rewriter.setInsertionPointToStart(endBlock);
|
||||
|
||||
GCNBuilder BuilderMemfenceLDS;
|
||||
BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()();
|
||||
BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
}
|
||||
}
|
||||
|
||||
// replace op
|
||||
if (TensorTy) {
|
||||
Type structTy = getTypeConverter()->convertType(TensorTy);
|
||||
Value resultStruct = getTypeConverter()->packLLElements(
|
||||
loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1095,40 +1161,81 @@ struct AtomicCASOpConversion
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
|
||||
// vec = 1 for scalar
|
||||
auto vec = getVectorSize(op.getPtr());
|
||||
// tensor
|
||||
if (TensorTy) {
|
||||
auto valTy = op.getVal().getType().cast<RankedTensorType>();
|
||||
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
|
||||
}
|
||||
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
Value casPtr = ptrElements[0];
|
||||
Value casCmp = cmpElements[0];
|
||||
Value casVal = valElements[0];
|
||||
for (size_t i = 0; i < elemsPerThread; i += vec) {
|
||||
Value casVal = undef(vecTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
Value iiVal = createIndexAttrConstant(
|
||||
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
||||
casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal);
|
||||
}
|
||||
|
||||
PTXBuilder ptxBuilderAtomicCAS;
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r", /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.global().o(semStr).o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value casPtr = ptrElements[i];
|
||||
Value casCmp = cmpElements[i];
|
||||
casVal = valElements[i];
|
||||
PTXBuilder ptxBuilderAtomicCAS;
|
||||
std::string tyId = valueElemNBits * vec == 64
|
||||
? "l"
|
||||
: (valueElemNBits * vec == 32 ? "r" : "h");
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId);
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, tyId);
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
auto sTy = "b" + std::to_string(valueElemNBits);
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
auto scope = stringifyMemSyncScope(op.getScope()).str();
|
||||
atom.global().o(semStr).o(scope).o("cas").o(sTy);
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value ret = load(atomPtr);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
if (TensorTy) {
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
resultVals[i + ii] =
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
|
||||
}
|
||||
} else {
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
// Only threads with mask = True store the result
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o(sTy);
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value ret = load(atomPtr);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
}
|
||||
}
|
||||
|
||||
if (TensorTy) {
|
||||
Type structTy = getTypeConverter()->convertType(TensorTy);
|
||||
Value resultStruct = getTypeConverter()->packLLElements(
|
||||
loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
@@ -1360,7 +1467,8 @@ struct AtomicRMWOpConversion
|
||||
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
|
||||
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
|
||||
|
||||
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
|
||||
auto scope = stringifyMemSyncScope(op.getScope()).str();
|
||||
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope);
|
||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||
auto sBits = std::to_string(valueElemNBits);
|
||||
switch (atomicRmwAttr) {
|
||||
@@ -2001,6 +2109,11 @@ private:
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (!isa<BlockArgument>(v) &&
|
||||
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
|
||||
v.getDefiningOp()))
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
@@ -10,6 +12,30 @@ using ::mlir::LLVM::getSRegValue;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
assert(axis >= 0);
|
||||
assert(axis < 3);
|
||||
assert(moduleOp);
|
||||
#ifdef USE_ROCM
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]);
|
||||
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
// "%clusterid".
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
|
||||
sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
return getSRegValue(rewriter, loc, sreg);
|
||||
#endif
|
||||
}
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
@@ -91,6 +117,12 @@ struct BroadcastOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
// The input print op contains:
|
||||
// - a "prefix" (string) specified by the user, and
|
||||
// - one or more "operands" (tensors).
|
||||
//
|
||||
// For each operand, we print all of the values contained in this GPU thread,
|
||||
// one per line, along with the index of the value in its tensor.
|
||||
struct PrintOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
@@ -100,50 +132,192 @@ struct PrintOpConversion
|
||||
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
SmallVector<Value, 16> operands;
|
||||
for (size_t i = 0; i < op.getNumOperands(); i++) {
|
||||
auto sub_operands = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
|
||||
for (auto elem : sub_operands) {
|
||||
operands.push_back(elem);
|
||||
Value prefixStr =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix());
|
||||
|
||||
auto getPid = [&](int axis) {
|
||||
return llGetPid(axis, loc, op->getParentOfType<ModuleOp>(), rewriter);
|
||||
};
|
||||
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
|
||||
|
||||
// Simple printf of a string without any tensors.
|
||||
if (op.getNumOperands() == 0) {
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
#ifdef USE_ROCM
|
||||
os << "pid (" << getFormatSubstr(pid[0]) << ", "
|
||||
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" << op.getPrefix().str();
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
|
||||
{pid[0], pid[1], pid[2]}, rewriter);
|
||||
#else
|
||||
os << "pid (" << getFormatSubstr(pid[0]) << ", "
|
||||
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s";
|
||||
llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter);
|
||||
#endif
|
||||
} else {
|
||||
for (size_t i = 0; i < op.getNumOperands(); i++) {
|
||||
// Elements of the tensor that are resident in this GPU thread.
|
||||
auto elems = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOperands()[i], rewriter,
|
||||
op.getOperand(i).getType());
|
||||
|
||||
// Get the indices of `elems` within the tensor. Note that if `elems`
|
||||
// has an "interesting" layout, then these will not be in any
|
||||
// particularly nice order.
|
||||
|
||||
// Extract the shape of the tensor being printed and use it to figure
|
||||
// out how many digits we need for each of the dimensions.
|
||||
SmallVector<int, 8> dimWidths;
|
||||
SmallVector<SmallVector<Value>> indices;
|
||||
if (auto rankedTy =
|
||||
op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
|
||||
indices =
|
||||
emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy);
|
||||
for (int64_t dim : rankedTy.getShape()) {
|
||||
if (dim > 0) {
|
||||
dimWidths.push_back(static_cast<int>(std::ceil(std::log10(dim))));
|
||||
} else {
|
||||
dimWidths.push_back(0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// We're printing a scalar.
|
||||
assert(elems.size() == 1);
|
||||
indices.push_back({});
|
||||
}
|
||||
|
||||
if (!elems.empty()) {
|
||||
printTensor(op, prefixStr, /*operand=*/i,
|
||||
/*numOperands=*/op.getNumOperands(), elems, pid, indices,
|
||||
dimWidths, rewriter);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
os << op.getPrefix();
|
||||
if (!operands.empty()) {
|
||||
os << getFormatSubstr(operands[0]);
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < operands.size(); ++i) {
|
||||
os << ", " << getFormatSubstr(operands[i]);
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr, operands,
|
||||
rewriter);
|
||||
#else
|
||||
llPrintf(formatStr, operands, rewriter);
|
||||
#endif
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
std::string getFormatSubstr(Value value) const {
|
||||
void printTensor(triton::PrintOp op, Value prefixStr, size_t operand, size_t numOperands,
|
||||
ArrayRef<Value> elems, std::array<Value, 3> pid,
|
||||
ArrayRef<SmallVector<Value>> indices,
|
||||
ArrayRef<int> dimWidths,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
assert(!elems.empty());
|
||||
assert(elems.size() == indices.size());
|
||||
assert(dimWidths.size() == indices.front().size());
|
||||
|
||||
size_t rank = dimWidths.size();
|
||||
|
||||
// Format is:
|
||||
// pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...)<prefix> (operand <n>) <elem>
|
||||
// where we leave off "(operand <n>)" if there's only one operand.
|
||||
//
|
||||
// The Python wrapper munges `prefix` so that it prints nicely (e.g. starts
|
||||
// with " " and ends with ": ").
|
||||
|
||||
Value formatStrValue;
|
||||
ConvertTritonGPUOpToLLVMPatternBase::PrintFormatting formatting;
|
||||
for (int i = 0; i < elems.size(); i++) {
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
|
||||
// nvptx printf can only accept 32 args; if we pass more than that, it
|
||||
// will print garbage for the trailing args.
|
||||
constexpr int kMaxPrintfOperands = 32;
|
||||
SmallVector<Value, kMaxPrintfOperands> printfOperands;
|
||||
|
||||
// TODO(jlebar): We really should pad the pid, but because the max pid is
|
||||
// not known at compile-time, this would require nontrivial device-side
|
||||
// work.
|
||||
os << "pid (";
|
||||
for (int j = 0; j < pid.size(); j++) {
|
||||
if (j != 0) {
|
||||
os << ", ";
|
||||
}
|
||||
os << getFormatSubstr(pid[j]);
|
||||
printfOperands.push_back(pid[j]);
|
||||
}
|
||||
os << ") ";
|
||||
|
||||
// If `rank` is large enough, we could end up exceeding
|
||||
// kMaxPrintfOperands. In that case, just truncate the index.
|
||||
// (Subtract 2 because we're going to add two operands after the index.)
|
||||
int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2;
|
||||
|
||||
os << "idx (";
|
||||
const auto &index = indices[i];
|
||||
for (size_t dim = 0; dim < index.size(); dim++) {
|
||||
if (dim != 0) {
|
||||
os << ", ";
|
||||
}
|
||||
if (dim == maxAllowedRank) {
|
||||
os << "... (truncated)";
|
||||
break;
|
||||
}
|
||||
os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]);
|
||||
printfOperands.push_back(index[dim]);
|
||||
}
|
||||
os << ")";
|
||||
|
||||
#if USE_ROCM
|
||||
os << op.getPrefix().str();
|
||||
#else
|
||||
os << "%s";
|
||||
printfOperands.push_back(prefixStr);
|
||||
#endif
|
||||
|
||||
if (numOperands > 1) {
|
||||
os << "(operand " << operand << ") ";
|
||||
}
|
||||
|
||||
auto elem = elems[i];
|
||||
os << getFormatSubstr(elem);
|
||||
printfOperands.push_back(elem);
|
||||
|
||||
// It's the same format string each iteration, but it's a lot easier if we
|
||||
// construct the format string at the same time as we populate
|
||||
// printfOperands. But we don't want to create BLOCK_SIZE duplicate
|
||||
// strings, so we cache the Value.
|
||||
if (i == 0) {
|
||||
#if USE_ROCM
|
||||
formatting = llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatStr,
|
||||
printfOperands, rewriter);
|
||||
#else
|
||||
formatStrValue = llPrintf(formatStr, printfOperands, rewriter);
|
||||
#endif
|
||||
} else {
|
||||
#ifdef USE_ROCM
|
||||
llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatting,
|
||||
printfOperands, rewriter);
|
||||
#else
|
||||
llPrintf(formatStrValue, printfOperands, rewriter);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string getFormatSubstr(Value value,
|
||||
std::optional<int> width = std::nullopt) const {
|
||||
std::string prefix = "%";
|
||||
if (width.has_value()) {
|
||||
prefix += std::to_string(*width);
|
||||
}
|
||||
|
||||
Type type = value.getType();
|
||||
if (type.isa<LLVM::LLVMPointerType>()) {
|
||||
return "%p";
|
||||
return prefix + "p";
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
|
||||
return "%f";
|
||||
return prefix + "f";
|
||||
} else if (type.isSignedInteger()) {
|
||||
if (type.getIntOrFloatBitWidth() == 64)
|
||||
return "%lli";
|
||||
return prefix + "lli";
|
||||
else
|
||||
return "%i";
|
||||
return prefix + "i";
|
||||
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
|
||||
if (type.getIntOrFloatBitWidth() == 64)
|
||||
return "%llu";
|
||||
return prefix + "llu";
|
||||
else
|
||||
return "%u";
|
||||
return prefix + "u";
|
||||
}
|
||||
assert(false && "not supported type");
|
||||
return "";
|
||||
@@ -199,9 +373,22 @@ struct PrintOpConversion
|
||||
return {newType, newOp};
|
||||
}
|
||||
|
||||
static void llPrintf(StringRef msg, ValueRange args,
|
||||
// Returns a Value for the format string, which you can reuse.
|
||||
static Value llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
assert(!msg.empty() && "printf with empty string not supported");
|
||||
llvm::SmallString<64> msgNewline(msg);
|
||||
msgNewline.push_back('\n');
|
||||
msgNewline.push_back('\0');
|
||||
Value msgValue =
|
||||
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()),
|
||||
rewriter, "printfFormat_", msgNewline);
|
||||
llPrintf(msgValue, args, rewriter);
|
||||
return msgValue;
|
||||
}
|
||||
|
||||
static void llPrintf(Value msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
assert(!msg.empty() && "printf with empty string not support");
|
||||
Type int8Ptr = ptr_ty(i8_ty);
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
@@ -213,11 +400,6 @@ struct PrintOpConversion
|
||||
Value one = i32_val(1);
|
||||
Value zero = i32_val(0);
|
||||
|
||||
llvm::SmallString<64> msgNewline(msg);
|
||||
msgNewline.push_back('\n');
|
||||
msgNewline.push_back('\0');
|
||||
Value prefixString =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
|
||||
Value bufferPtr = null(int8Ptr);
|
||||
|
||||
SmallVector<Value, 16> newArgs;
|
||||
@@ -245,7 +427,7 @@ struct PrintOpConversion
|
||||
bufferPtr = bitcast(allocated, int8Ptr);
|
||||
}
|
||||
|
||||
SmallVector<Value> operands{prefixString, bufferPtr};
|
||||
SmallVector<Value> operands{msg, bufferPtr};
|
||||
call(funcOp, operands);
|
||||
}
|
||||
};
|
||||
@@ -447,32 +629,14 @@ struct GetProgramIdOpConversion
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(),
|
||||
op->getParentOfType<ModuleOp>(), rewriter);
|
||||
#ifdef USE_ROCM
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
return success();
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, programId);
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
// "%clusterid".
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
|
||||
sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
|
||||
Value programId = getSRegValue(rewriter, loc, sreg);
|
||||
rewriter.replaceOp(op, programId);
|
||||
return success();
|
||||
#endif
|
||||
return success();
|
||||
}
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
@@ -770,28 +934,9 @@ struct AsyncBulkCommitGroupOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
} // namespace
|
||||
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
PrintOpConversion::llPrintf(msg, args, rewriter);
|
||||
}
|
||||
|
||||
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
std::string elem_repr, ConversionPatternRewriter &builder) {
|
||||
std::string fmt = info + " t-%d ";
|
||||
std::vector<Value> new_arr({thread});
|
||||
for (int i = 0; i < arr.size(); ++i) {
|
||||
fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
|
||||
new_arr.push_back(arr[i]);
|
||||
}
|
||||
|
||||
vprintf(fmt, new_arr, builder);
|
||||
}
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
namespace mlir::triton {
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
@@ -818,3 +963,5 @@ void populateTritonGPUToLLVMPatterns(
|
||||
patterns.add<PrintOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AssertOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
} // namespace mlir::triton
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace mlir::triton {
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
@@ -13,4 +15,6 @@ void populateTritonGPUToLLVMPatterns(
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
} // namespace mlir::triton
|
||||
|
||||
#endif
|
||||
|
||||
@@ -38,19 +38,6 @@ namespace ttng = ::mlir::triton::nvidia_gpu;
|
||||
|
||||
typedef DenseMap<Operation *, triton::MakeTensorPtrOp> TensorPtrMapT;
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
// Helper function for using printf in LLVM conversion.
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
std::string elem_repr, ConversionPatternRewriter &builder);
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
||||
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
|
||||
// since it is not exposed on header files in mlir v14
|
||||
@@ -193,10 +180,16 @@ public:
|
||||
// Key: {layout, shape, withCTAOffset}
|
||||
struct IndexCacheInfo {
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
*baseIndexCache;
|
||||
*baseIndexCache = nullptr;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||
CacheKeyDenseMapInfo> *indexCache;
|
||||
OpBuilder::InsertPoint *indexInsertPoint;
|
||||
CacheKeyDenseMapInfo> *indexCache = nullptr;
|
||||
OpBuilder::InsertPoint *indexInsertPoint = nullptr;
|
||||
};
|
||||
|
||||
struct PrintFormatting
|
||||
{
|
||||
Value formatStrValue;
|
||||
size_t formatStrSize;
|
||||
};
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(
|
||||
@@ -837,7 +830,7 @@ public:
|
||||
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
"emitIndices for layouts other than blocked & slice not "
|
||||
"emitIndices for layouts other than blocked, mma, and slice not "
|
||||
"implemented yet");
|
||||
}
|
||||
if (cache) {
|
||||
@@ -1332,15 +1325,31 @@ private:
|
||||
}
|
||||
|
||||
protected:
|
||||
// Returns a Value for the format string, which you can reuse.
|
||||
PrintFormatting llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
|
||||
ValueRange args, ConversionPatternRewriter &rewriter,
|
||||
bool stderr = false) const {
|
||||
assert(!msg.empty() && "printf with empty string not supported");
|
||||
PrintFormatting formatting;
|
||||
llvm::SmallString<32> msgNewline(msg);
|
||||
msgNewline.push_back('\n');
|
||||
msgNewline.push_back('\0');
|
||||
formatting.formatStrValue =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
|
||||
formatting.formatStrSize = msgNewline.size_in_bytes();
|
||||
llPrintfHIP(loc, moduleOp, formatting, args, rewriter, stderr);
|
||||
return formatting;
|
||||
}
|
||||
|
||||
// The code is borrowed from https://reviews.llvm.org/D110448
|
||||
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
|
||||
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
|
||||
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, PrintFormatting formatting,
|
||||
ValueRange args, ConversionPatternRewriter &rewriter,
|
||||
bool stderr = false) const {
|
||||
|
||||
auto typeConverter = getTypeConverter();
|
||||
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
|
||||
mlir::Type i8Ptr = typeConverter->getPointerType(llvmI8);
|
||||
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
|
||||
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
|
||||
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
|
||||
|
||||
@@ -1362,7 +1371,7 @@ protected:
|
||||
moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
|
||||
LLVM::LLVMFunctionType::get(
|
||||
llvmI64,
|
||||
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
|
||||
{llvmI64, {ptrType}, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
|
||||
|
||||
/// Start the printf hostcall
|
||||
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
|
||||
@@ -1373,19 +1382,11 @@ protected:
|
||||
// Get a unique global name for the format.
|
||||
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
|
||||
|
||||
SmallString<32> formatString(msg);
|
||||
formatString.push_back('\n'); // Triton adds CR for each print.
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
size_t formatStringSize = formatString.size_in_bytes();
|
||||
|
||||
Value prefixString =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", formatString);
|
||||
|
||||
auto prefixPtrType = ocklAppendStringN.getArgumentTypes()[1];
|
||||
prefixString = bitcast(prefixString, prefixPtrType);
|
||||
Value prefixString = bitcast(formatting.formatStrValue, prefixPtrType);
|
||||
|
||||
Value stringLen =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatting.formatStrSize);
|
||||
|
||||
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
|
||||
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
|
||||
@@ -1411,12 +1412,11 @@ protected:
|
||||
Value arg = args[i];
|
||||
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
|
||||
if (!floatType.isF64())
|
||||
arg = rewriter.create<LLVM::FPExtOp>(
|
||||
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
|
||||
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
|
||||
arg = fpext(typeConverter->convertType(rewriter.getF64Type()), arg);
|
||||
arg = bitcast(arg, llvmI64);
|
||||
}
|
||||
if (arg.getType().getIntOrFloatBitWidth() != 64)
|
||||
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
|
||||
arg = zext(llvmI64, arg);
|
||||
|
||||
arguments.push_back(arg);
|
||||
}
|
||||
@@ -1427,7 +1427,7 @@ protected:
|
||||
|
||||
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
|
||||
arguments.push_back(isLast);
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
|
||||
auto call = call(ocklAppendArgs, arguments);
|
||||
printfDesc = call.getResult();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,8 +64,8 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
|
||||
for (size_t i = 0; i < 2 * shape.size(); ++i)
|
||||
types.push_back(IntegerType::get(ctx, 64));
|
||||
|
||||
types.push_back(
|
||||
LLVM::LLVMPointerType::get(eleType, type.getAddressSpace()));
|
||||
types.push_back(LLVM::LLVMPointerType::get(convertType(eleType),
|
||||
type.getAddressSpace()));
|
||||
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
}
|
||||
|
||||
@@ -442,8 +442,10 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
|
||||
}
|
||||
|
||||
Value zero = i32_val(0);
|
||||
Value globalPtr =
|
||||
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), global);
|
||||
Type globalPtrType =
|
||||
LLVM::LLVMPointerType::get(globalType, global.getAddrSpace());
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
UnknownLoc::get(ctx), globalPtrType, global.getSymName());
|
||||
Value stringStart =
|
||||
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
|
||||
globalPtr, SmallVector<Value>({zero, zero}));
|
||||
|
||||
@@ -147,6 +147,8 @@ struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
|
||||
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
assert(!triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType()) &&
|
||||
"expensive view not supported");
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
auto vals = this->getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
|
||||
|
||||
@@ -46,23 +46,6 @@ template <class Op> struct GenericOpPattern : public OpConversionPattern<Op> {
|
||||
}
|
||||
};
|
||||
|
||||
template <class SrcOp, class DstOp>
|
||||
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
||||
public:
|
||||
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
||||
@@ -122,8 +105,9 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
|
||||
GenericOpPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
GenericOpPattern<arith::CmpIOp>, GenericOpPattern<arith::CmpFOp>,
|
||||
// Select
|
||||
GenericOpPattern<arith::SelectOp>,
|
||||
// Cast Ops
|
||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
||||
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
|
||||
@@ -132,45 +116,6 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
// this shouldn't exist if mlir's SelectOp checked encodings properly
|
||||
class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
|
||||
Value cond = adaptor.getCondition();
|
||||
if (llvm::isa<RankedTensorType>(retType) &&
|
||||
!llvm::isa<TensorType>(cond.getType())) {
|
||||
// triton_gpu.select doesn't support scalar condition values, so add a
|
||||
// splat
|
||||
auto retTypeTensor = llvm::cast<RankedTensorType>(retType);
|
||||
auto retShape = retTypeTensor.getShape();
|
||||
auto retEncoding = retTypeTensor.getEncoding();
|
||||
Type condTy =
|
||||
RankedTensorType::get(retShape, cond.getType(), retEncoding);
|
||||
cond = rewriter.create<triton::SplatOp>(op.getLoc(), condTy, cond);
|
||||
}
|
||||
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
||||
op, retType, cond, adaptor.getTrueValue(), adaptor.getFalseValue()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Rewrite rule
|
||||
patterns.add<StdSelectPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
@@ -529,6 +474,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
GenericOpPattern<triton::StoreOp>,
|
||||
GenericOpPattern<triton::ExternElementwiseOp>,
|
||||
GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
|
||||
GenericOpPattern<triton::AtomicCASOp>,
|
||||
GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
|
||||
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
|
||||
context);
|
||||
@@ -745,7 +691,6 @@ public:
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateStdPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateArithPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateMathPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns, numCTAs);
|
||||
|
||||
@@ -52,7 +52,7 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
return dotLayout.getTotalElemsPerThread(shape, eltTy);
|
||||
} else {
|
||||
assert(0 && "getElemsPerThread not implemented");
|
||||
llvm::report_fatal_error("getElemsPerThread not implemented");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -68,7 +68,7 @@ SmallVector<unsigned> getElemsPerThread(Attribute layout,
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
return mfmaLayout.getElemsPerThread(shape, eltTy);
|
||||
} else {
|
||||
assert(0 && "getElemsPerThread not implemented");
|
||||
llvm::report_fatal_error("getElemsPerThread not implemented");
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
|
||||
return threadsPerWarp;
|
||||
}
|
||||
assert(0 && "getThreadsPerWarp not implemented");
|
||||
llvm::report_fatal_error("getThreadsPerWarp not implemented");
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -180,15 +180,17 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parent = sliceLayout.getParent();
|
||||
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
|
||||
assert(parentWarpsPerCTA.size() == 2 &&
|
||||
"getWarpsPerCTA only implemented for 2D slice layout");
|
||||
assert(parentWarpsPerCTA.size() == 2 ||
|
||||
parentWarpsPerCTA[sliceLayout.getDim()] == 1 &&
|
||||
"getWarpsPerCTA only implemented for 2D slice layout or the "
|
||||
"slice dim must have 1 warp in the parent layout");
|
||||
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
|
||||
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
|
||||
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
|
||||
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
|
||||
return warpsPerCTA;
|
||||
}
|
||||
assert(0 && "getWarpsPerCTA not implemented");
|
||||
llvm::report_fatal_error("getWarpsPerCTA not implemented");
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -271,7 +273,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
} else if (opIdx == 1) {
|
||||
return {4, 1};
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
return {};
|
||||
}
|
||||
} else if (parentLayout.isa<MfmaEncodingAttr>()) {
|
||||
@@ -285,12 +287,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
llvm::report_fatal_error(
|
||||
"DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "getSizePerThread not implemented");
|
||||
llvm::report_fatal_error("getSizePerThread not implemented");
|
||||
return {};
|
||||
}
|
||||
}
|
||||
@@ -344,7 +347,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
threads = {8 * mmaLayout.getWarpsPerCTA()[0],
|
||||
4 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
} else
|
||||
assert(0 && "Unimplemented usage of MmaEncodingAttr");
|
||||
llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
@@ -354,7 +357,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
4 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getThreadsPerCTA");
|
||||
llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA");
|
||||
}
|
||||
|
||||
return threads;
|
||||
@@ -388,7 +391,7 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
instrShape[1] * mmaLayout.getWarpsPerCTA()[1]};
|
||||
}
|
||||
assert(0 && "Unexpected MMA layout version found");
|
||||
llvm::report_fatal_error("Unexpected MMA layout version found");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
|
||||
@@ -408,7 +411,7 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
} else if (opIdx == 1) {
|
||||
return {16, parentShapePerCTATile[1]};
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
}
|
||||
} else if (auto parentMfmaLayout =
|
||||
parentLayout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
@@ -423,15 +426,20 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
llvm::report_fatal_error(
|
||||
"DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTATile");
|
||||
llvm::report_fatal_error("Unimplemented usage of getShapePerCTATile");
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
bool isExpensiveView(Type srcType, Type dstType) {
|
||||
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
|
||||
@@ -480,7 +488,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
|
||||
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getOrder");
|
||||
llvm::report_fatal_error("Unimplemented usage of getOrder");
|
||||
}
|
||||
return {};
|
||||
};
|
||||
@@ -501,7 +509,7 @@ CTALayoutAttr getCTALayout(Attribute layout) {
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
return sharedLayout.getCTALayout();
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getCTALayout");
|
||||
llvm::report_fatal_error("Unimplemented usage of getCTALayout");
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -529,7 +537,8 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
|
||||
* in the branch where layout is an instance of SliceEncodingAttr. This is
|
||||
* inconvenient but safe.
|
||||
*/
|
||||
assert(0 && "getCTAsPerCGA for SliceEncodingAttr is not well-defined");
|
||||
llvm::report_fatal_error(
|
||||
"getCTAsPerCGA for SliceEncodingAttr is not well-defined");
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
|
||||
ref = mmaLayout.getCTALayout().getCTAsPerCGA();
|
||||
#ifdef USE_ROCM
|
||||
@@ -541,7 +550,7 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
ref = sharedLayout.getCTALayout().getCTAsPerCGA();
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getCTAsPerCGA");
|
||||
llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA");
|
||||
return SmallVector<unsigned>(ref.begin(), ref.end());
|
||||
}
|
||||
|
||||
@@ -596,7 +605,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
|
||||
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
ref = sharedLayout.getCTALayout().getCTAOrder();
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getCTAOrder");
|
||||
llvm::report_fatal_error("Unimplemented usage of getCTAOrder");
|
||||
}
|
||||
return SmallVector<unsigned>(ref.begin(), ref.end());
|
||||
}
|
||||
@@ -649,9 +658,9 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
|
||||
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
|
||||
return getNumWarpsPerCTA(dotLayout.getParent());
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
assert(0 && "Cannot get numWarps from SharedEncodingAttr");
|
||||
llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getNumWarpsPerCTA");
|
||||
llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA");
|
||||
return product<unsigned>(warpsPerCTA);
|
||||
}
|
||||
|
||||
@@ -672,7 +681,7 @@ unsigned getNumCTAs(Attribute layout) {
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getNumCTAs");
|
||||
llvm::report_fatal_error("Unimplemented usage of getNumCTAs");
|
||||
return product<unsigned>(CTAsPerCGA);
|
||||
}
|
||||
|
||||
@@ -1787,13 +1796,15 @@ struct CanonicalizeConvertFromView
|
||||
Operation *arg = op->getOperand(0).getDefiningOp();
|
||||
if (!arg)
|
||||
return mlir::failure();
|
||||
auto convert = dyn_cast<ConvertLayoutOp>(arg);
|
||||
if (!convert)
|
||||
return failure();
|
||||
if (isExpensiveView(convert.getOperand().getType(), op.getType()))
|
||||
return failure();
|
||||
// view(convert) -> view
|
||||
if (auto convert = dyn_cast<ConvertLayoutOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::ViewOp>(
|
||||
op, op->getResult(0).getType(), convert.getOperand());
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
|
||||
convert.getOperand());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1839,6 +1850,8 @@ struct CanonicalizeConvertFromConvert
|
||||
return mlir::failure();
|
||||
// cvt(view) -> view
|
||||
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
|
||||
if (isExpensiveView(view.getOperand().getType(), op.getType()))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<triton::ViewOp>(
|
||||
op, op->getResult(0).getType(), view.getResult());
|
||||
return mlir::success();
|
||||
|
||||
@@ -70,10 +70,15 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, {filter});
|
||||
auto slices = multiRootGetSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp)) {
|
||||
if (shape[0] >= shape[1]) {
|
||||
return {(unsigned)numWarps, 1};
|
||||
} else {
|
||||
return {1, (unsigned)numWarps};
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
@@ -133,8 +138,18 @@ class BlockedToMMA : public mlir::RewritePattern {
|
||||
mlir::TypeID::get<arith::ArithDialect>());
|
||||
}
|
||||
|
||||
// finds the first different value bitwidth in the chain of
|
||||
// shape-preserving unary ops that x depends on
|
||||
// Finds the first different bitwidth in the chain of shape-preserving
|
||||
// unary ops that x depends on.
|
||||
// There are two primary scenarios:
|
||||
// (1) Upcasting: A sequence such as loading an fp16, followed by arithmetic
|
||||
// operations, then bitcasting to fp32, and finally computing in fp32.
|
||||
// (2) Downcasting: This might involve loading an fp32, performing arithmetic
|
||||
// operations, bitcasting to fp16, and finally computing in fp16.
|
||||
// In the upcasting scenario, element reordering converts the original
|
||||
// elements distribution to the order of higher precision primitives. As a
|
||||
// result, kwidth can be the bitwidth of the lower precision primitive.
|
||||
// Conversely, in the downcasting scenario, no reordering is performed,
|
||||
// making it directory use the lower precision primitive.
|
||||
static int computeOrigBitWidth(Value x) {
|
||||
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
|
||||
int origBitWidth = finalBitWidth;
|
||||
@@ -143,11 +158,17 @@ class BlockedToMMA : public mlir::RewritePattern {
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = bwdFilter;
|
||||
getBackwardSlice(x, &slice, opt);
|
||||
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
|
||||
if (firstOp)
|
||||
if (Value arg = firstOp->getOperand(0))
|
||||
if (RankedTensorType argTy = arg.getType().dyn_cast<RankedTensorType>())
|
||||
origBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
|
||||
for (auto op : slice) {
|
||||
if (Value arg = op->getOperand(0))
|
||||
if (RankedTensorType argTy =
|
||||
arg.getType().dyn_cast<RankedTensorType>()) {
|
||||
auto argBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
|
||||
if (argBitWidth != origBitWidth) {
|
||||
origBitWidth = std::min<int>(origBitWidth, argBitWidth);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return origBitWidth;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ add_mlir_dialect_library(TritonGPUTransforms
|
||||
DecomposeConversions.cpp
|
||||
OptimizeDotOperands.cpp
|
||||
OptimizeEpilogue.cpp
|
||||
Pipeline.cpp
|
||||
OptimizeThreadLocality.cpp
|
||||
Pipeliner/MatmulLoopPipeline.cpp
|
||||
Pipeliner/PipelineExpander.cpp
|
||||
Pipeliner/SoftwarePipeliner.cpp
|
||||
Prefetch.cpp
|
||||
RemoveLayoutConversions.cpp
|
||||
ReorderInstructions.cpp
|
||||
|
||||
312
lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp
Normal file
312
lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp
Normal file
@@ -0,0 +1,312 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include <memory>
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
class TritonGPUOptimizeThreadLocalityPass
|
||||
: public TritonGPUOptimizeThreadLocalityBase<
|
||||
TritonGPUOptimizeThreadLocalityPass> {
|
||||
void runOnOperation() override {
|
||||
ModuleOp mod = getOperation();
|
||||
DenseSet<triton::ReduceOp> reduceOps;
|
||||
mod.walk([&](triton::ReduceOp reduce) -> void {
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto rank = srcType.getShape().size();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
auto reductionOp = getReductionOp(reduce);
|
||||
if (!reductionOp ||
|
||||
!isa<arith::AddFOp, arith::MaximumFOp, arith::MinimumFOp,
|
||||
arith::MulFOp>(reductionOp.value()))
|
||||
return;
|
||||
// TODO: relax this restriction
|
||||
if (!(srcEncoding.isa<triton::gpu::BlockedEncodingAttr>() && rank > 1))
|
||||
return;
|
||||
for (auto operand : reduce->getOperands()) {
|
||||
auto def = operand.getDefiningOp();
|
||||
if (!isa<triton::LoadOp>(def))
|
||||
return;
|
||||
}
|
||||
auto elemsPerThread =
|
||||
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
|
||||
// Not worth applying this optimization if there is only one element per
|
||||
// thread on the reduction axis
|
||||
if (elemsPerThread == 1)
|
||||
return;
|
||||
if (!reduce->hasOneUse())
|
||||
return;
|
||||
Operation *user = *(reduce->getUsers().begin());
|
||||
if (!user->hasOneUse())
|
||||
return;
|
||||
OpOperand &yieldOpOperand = *(user->getUses().begin());
|
||||
auto yieldOp = dyn_cast<scf::YieldOp>(yieldOpOperand.getOwner());
|
||||
if (!yieldOp)
|
||||
return;
|
||||
auto operandNumber = yieldOpOperand.getOperandNumber();
|
||||
Block *block = reduce->getBlock();
|
||||
Operation *parentOp = block->getParentOp();
|
||||
auto forOp = dyn_cast<scf::ForOp>(parentOp);
|
||||
if (!forOp)
|
||||
return;
|
||||
auto argNum = yieldOpOperand.getOperandNumber();
|
||||
auto oldAccum = forOp.getInitArgs()[argNum];
|
||||
auto cstOp = dyn_cast<arith::ConstantOp>(oldAccum.getDefiningOp());
|
||||
if (!cstOp)
|
||||
return;
|
||||
reduceOps.insert(reduce);
|
||||
});
|
||||
|
||||
for (auto reduce : reduceOps) {
|
||||
OpBuilder builder(reduce);
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcType.getShape();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
assert(srcEncoding.isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
"Thread locality optimization only supports blocked encoding");
|
||||
auto blocked = srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto elemsPerThread =
|
||||
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
|
||||
auto rank = srcShape.size();
|
||||
// create new layouts
|
||||
auto blocked3d = getThreadLocalityOptimizedEncoding(reduce);
|
||||
auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce);
|
||||
auto viewOpTensorType = RankedTensorType::get(
|
||||
viewOpTensorShape, srcType.getElementType(), blocked3d);
|
||||
auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank,
|
||||
blocked3d);
|
||||
// Get forOp
|
||||
assert(reduce->hasOneUse());
|
||||
OpOperand &use = *(reduce->getUses().begin());
|
||||
auto operandNumber = use.getOperandNumber();
|
||||
auto oldUpdate = use.getOwner();
|
||||
assert(oldUpdate->getNumOperands() == 2);
|
||||
auto accumOperandNumber = (operandNumber == 0) ? 1 : 0;
|
||||
auto accumOperand = oldUpdate->getOperand(accumOperandNumber);
|
||||
assert(accumOperand.isa<BlockArgument>());
|
||||
auto blockArg = accumOperand.dyn_cast<BlockArgument>();
|
||||
auto blockArgNum = blockArg.getArgNumber();
|
||||
auto forOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
|
||||
// get oldAccum
|
||||
auto oldAccum =
|
||||
forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()];
|
||||
// get old loop user
|
||||
Value loopResult =
|
||||
forOp.getResult(blockArgNum - forOp.getNumInductionVars());
|
||||
assert(loopResult.hasOneUse());
|
||||
OpOperand &loopUse = *(loopResult.getUses().begin());
|
||||
Operation *loopUser = loopUse.getOwner();
|
||||
// get old loop yield
|
||||
auto oldYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
// create newAccum initialization
|
||||
auto newAccum =
|
||||
createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d);
|
||||
// create new loop by copying the old for op signature and appending
|
||||
// newAccum to the block arguments
|
||||
auto newLoop = replaceForOpWithNewSignature(
|
||||
builder, forOp, ValueRange{newAccum->getResult(0)});
|
||||
// create thread local reduction (also adds viewOps)
|
||||
auto newReduce = createReduce(builder, reduce, viewOpTensorType);
|
||||
|
||||
// create new accum update
|
||||
auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate);
|
||||
// create new yield
|
||||
auto newYield = createYield(builder, newLoop, oldYield,
|
||||
newUpdate->getResult(0), blockArgNum);
|
||||
// create post loop reduction on the original reduce axis
|
||||
auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce);
|
||||
// add convert_layout to get back to original layout, the result layout
|
||||
// should now match the layout of the old accumulator (%cst)
|
||||
Type destType = loopResult.getType();
|
||||
auto cvtLayout = createConvertLayout(builder, destType, newReduce2);
|
||||
// incorporate the original accumulator value into the final result
|
||||
auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate,
|
||||
cvtLayout, oldAccum);
|
||||
// Replace the old loop user with the final result
|
||||
loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0));
|
||||
|
||||
// cleanup
|
||||
oldYield.erase();
|
||||
forOp.erase();
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
std::optional<Operation *> getReductionOp(triton::ReduceOp reduce) const {
|
||||
auto numRegions = reduce->getNumRegions();
|
||||
if (numRegions != 1)
|
||||
return std::nullopt;
|
||||
Region ®ion = reduce->getRegion(0);
|
||||
auto numBlocks = region.getBlocks().size();
|
||||
if (numBlocks != 1)
|
||||
return std::nullopt;
|
||||
Block &block = region.front();
|
||||
auto blockWithoutTerminator = block.without_terminator();
|
||||
auto blockSizeWithoutTerminator = std::distance(
|
||||
blockWithoutTerminator.begin(), blockWithoutTerminator.end());
|
||||
if (blockSizeWithoutTerminator != 1)
|
||||
return std::nullopt;
|
||||
Operation *op = &block.front();
|
||||
return std::optional<Operation *>(op);
|
||||
}
|
||||
Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder,
|
||||
Operation *oldUpdate,
|
||||
Operation *cvtLayout,
|
||||
Value oldAccum) const {
|
||||
builder.setInsertionPointAfter(cvtLayout);
|
||||
IRMapping mapping;
|
||||
mapping.map(oldUpdate->getOperand(0), oldAccum);
|
||||
mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0));
|
||||
auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping);
|
||||
return finalOp;
|
||||
}
|
||||
Operation *createConvertLayout(OpBuilder &builder, Type destType,
|
||||
Operation *newReduce) const {
|
||||
builder.setInsertionPointAfter(newReduce);
|
||||
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
newReduce->getLoc(), destType, newReduce->getResult(0));
|
||||
return newCvt;
|
||||
}
|
||||
|
||||
Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop,
|
||||
triton::ReduceOp &reduce) const {
|
||||
auto resultIndex =
|
||||
loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars();
|
||||
auto newLoopResult = loop.getResult(resultIndex);
|
||||
builder.setInsertionPointAfter(loop);
|
||||
IRMapping mapping;
|
||||
mapping.map(*(reduce.getOperands().begin()), newLoopResult);
|
||||
auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping);
|
||||
return newReduce2;
|
||||
}
|
||||
|
||||
Operation *createYield(OpBuilder &builder, scf::ForOp &loop,
|
||||
scf::YieldOp &oldYield, Value newUpdate,
|
||||
int oldAccumBlockArgNum) const {
|
||||
builder.setInsertionPoint(oldYield);
|
||||
SmallVector<Value> yieldValues = llvm::to_vector(oldYield.getOperands());
|
||||
yieldValues[oldAccumBlockArgNum - 1] =
|
||||
loop.getBody()->getArgument(oldAccumBlockArgNum);
|
||||
yieldValues.push_back(newUpdate);
|
||||
auto newYield =
|
||||
builder.create<scf::YieldOp>(oldYield.getLoc(), yieldValues);
|
||||
return newYield;
|
||||
}
|
||||
|
||||
Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop,
|
||||
Operation *newReduce, Operation *oldUpdate) const {
|
||||
auto blockArgNum = loop.getBody()->getNumArguments() - 1;
|
||||
auto newArg = loop.getBody()->getArgument(blockArgNum);
|
||||
builder.setInsertionPointAfter(newReduce);
|
||||
IRMapping mapping;
|
||||
mapping.map(oldUpdate->getOperand(0), newArg);
|
||||
mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0));
|
||||
auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping);
|
||||
return newUpdate;
|
||||
}
|
||||
|
||||
Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce,
|
||||
Type viewOpTensorType) const {
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto rank = srcType.getShape().size();
|
||||
builder.setInsertionPointAfter(reduce);
|
||||
IRMapping mapping;
|
||||
for (auto operand : reduce.getOperands()) {
|
||||
auto viewOp = builder.create<triton::ViewOp>(reduce.getLoc(),
|
||||
viewOpTensorType, operand);
|
||||
mapping.map(operand, viewOp);
|
||||
}
|
||||
|
||||
auto newReduce = cloneWithInferType(builder, &(*reduce), mapping);
|
||||
newReduce->setAttr("axis", builder.getI32IntegerAttr(rank));
|
||||
auto typeInfer = dyn_cast<InferTypeOpInterface>(newReduce);
|
||||
if (typeInfer) {
|
||||
SmallVector<Type, 1> newTypes;
|
||||
auto success = typeInfer.inferReturnTypes(
|
||||
newReduce->getContext(), newReduce->getLoc(),
|
||||
newReduce->getOperands(), newReduce->getAttrDictionary(),
|
||||
newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes);
|
||||
if (succeeded(success)) {
|
||||
for (size_t i = 0; i < newTypes.size(); i++)
|
||||
newReduce->getResult(i).setType(newTypes[i]);
|
||||
}
|
||||
}
|
||||
return newReduce;
|
||||
}
|
||||
|
||||
Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce,
|
||||
Value &oldAccum, SmallVector<int64_t> &shape,
|
||||
Attribute &slice2d) const {
|
||||
// Drop the last dimension (thread locality dimension)
|
||||
SmallVector<int64_t> accumShape(shape.begin(), shape.end() - 1);
|
||||
auto elemType =
|
||||
oldAccum.getType().cast<RankedTensorType>().getElementType();
|
||||
// Create tensor type for the new accumulator
|
||||
auto accumType = RankedTensorType::get(accumShape, elemType, slice2d);
|
||||
// Create new accumulator
|
||||
builder.setInsertionPointAfter(oldAccum.getDefiningOp());
|
||||
auto reductionOp = getReductionOp(reduce);
|
||||
assert(reductionOp && "Processing a reduce that is not supported!");
|
||||
auto neutralVal = mlir::arith::getNeutralElement(reductionOp.value());
|
||||
assert(neutralVal && "Could not find neutral value for reduction op!");
|
||||
auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value());
|
||||
auto newAccum = builder.create<arith::ConstantOp>(oldAccum.getLoc(),
|
||||
accumType, denseAttr);
|
||||
return newAccum;
|
||||
}
|
||||
|
||||
SmallVector<int64_t>
|
||||
getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const {
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcType.getShape();
|
||||
auto rank = srcShape.size();
|
||||
auto elemsPerThread =
|
||||
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
|
||||
auto viewOpTensorShape = insertValue(srcShape, rank, 1);
|
||||
viewOpTensorShape[reduce.getAxis()] /= elemsPerThread;
|
||||
viewOpTensorShape[rank] = elemsPerThread;
|
||||
return viewOpTensorShape;
|
||||
}
|
||||
|
||||
// Derive a rank+1 blocked encoding from the reduce operand's encoding where
// the per-thread elements along the reduce axis are moved into a new
// innermost dimension (made the fastest-varying one in the order).
Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const {
  auto operandType =
      reduce.getOperands()[0].getType().cast<RankedTensorType>();
  auto numDims = operandType.getShape().size();
  auto blocked =
      operandType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
  // Transfer the per-thread extent of the reduce axis onto the new dim.
  auto sizePerThread3d =
      insertValue(blocked.getSizePerThread(), numDims,
                  blocked.getSizePerThread()[reduce.getAxis()]);
  sizePerThread3d[reduce.getAxis()] = 1;
  // Threads, warps and CTAs do not split the new dimension.
  auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), numDims, 1);
  auto warpsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), numDims, 1);
  // Prepend the new dimension's index so it varies fastest.
  auto order3d = insertValue(blocked.getOrder(), 0, numDims);
  auto ctasPerCGA3d =
      insertValue(blocked.getCTALayout().getCTAsPerCGA(), numDims, 1);
  auto ctasSplitNum3d =
      insertValue(blocked.getCTALayout().getCTASplitNum(), numDims, 1);
  auto ctaOrder3d =
      insertValue(blocked.getCTALayout().getCTAOrder(), numDims, numDims);
  auto ctaLayout3d = triton::gpu::CTALayoutAttr::get(
      reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d);
  return triton::gpu::BlockedEncodingAttr::get(
      reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warpsPerCTA3d,
      order3d, ctaLayout3d);
}
|
||||
|
||||
template <typename T>
|
||||
SmallVector<T> insertValue(ArrayRef<T> vec, unsigned index, int value) const {
|
||||
SmallVector<T> res(vec.begin(), vec.end());
|
||||
res.insert(res.begin() + index, static_cast<T>(value));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
// Factory for the TritonGPU thread-locality optimization pass, used when
// registering the pass with the pipeline.
std::unique_ptr<Pass> mlir::createTritonGPUOptimizeThreadLocalityPass() {
  return std::make_unique<TritonGPUOptimizeThreadLocalityPass>();
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,814 @@
|
||||
#include "PipelineExpander.h"
|
||||
#include "Schedule.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define int_attr(num) builder.getI64IntegerAttr(num)
|
||||
|
||||
using namespace mlir;
|
||||
namespace tt = mlir::triton;
|
||||
namespace ttg = mlir::triton::gpu;
|
||||
namespace ttng = mlir::triton::nvidia_gpu;
|
||||
|
||||
// TODO: We can extra some helpers into common utilities once we add more
|
||||
// schedules.
|
||||
|
||||
/// Replace the yield with a new one with the given operands appended.
|
||||
static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
|
||||
// Fix up the yield op.
|
||||
Operation *yieldOp = forOp.getBody()->getTerminator();
|
||||
SmallVector<Value> operands(yieldOp->getOperands().begin(),
|
||||
yieldOp->getOperands().end());
|
||||
operands.append(newOperands.begin(), newOperands.end());
|
||||
OpBuilder builder(yieldOp);
|
||||
builder.create<scf::YieldOp>(yieldOp->getLoc(), operands);
|
||||
yieldOp->erase();
|
||||
}
|
||||
|
||||
// Rewrite a pipelined tt::LoadOp into a cp.async sequence:
// insert_slice_async into the multi-buffered `alloc` at `insertIdx`, an
// async_commit_group, and an extract_slice at `extractIdx` feeding the
// load's (single) convert_layout user. The load and old convert are erased
// and the insert result is added to the loop's yield.
// Fix vs. original: the AsyncCommitGroupOp result was bound to an unused,
// misspelled local (`commmit`); the op is now created without the dead
// variable (same IR emitted).
static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                            Value insertIdx, Value extractIdx) {
  OpBuilder builder(forOp);
  // Replace the load with insert/extract slice.
  builder.setInsertionPoint(loadOp);
  Location loc = loadOp.getLoc();
  auto insertOp = builder.create<ttg::InsertSliceAsyncOp>(
      loc, alloc.getType(), loadOp.getPtr(), alloc, insertIdx, loadOp.getMask(),
      loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(),
      loadOp.getIsVolatile(), /*axis*/ 0);
  // Close the async copy group; the op's result is intentionally unused.
  builder.create<ttg::AsyncCommitGroupOp>(loc);

  // Extract part: slice buffer `extractIdx` out of the 3-D allocation.
  auto allocType = alloc.getType().cast<RankedTensorType>();
  RankedTensorType sliceType = RankedTensorType::get(
      {allocType.getShape()[1], allocType.getShape()[2]},
      allocType.getElementType(), allocType.getEncoding());
  auto extract = builder.create<ttg::ExtractSliceOp>(
      loc, sliceType, insertOp.getResult(),
      SmallVector<OpFoldResult>{extractIdx, int_attr(0), int_attr(0)},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(sliceType.getShape()[0]),
                                int_attr(sliceType.getShape()[1])},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
  // The load is guaranteed (by loadDotOperand) to have exactly one user,
  // a convert_layout; re-point it at the extracted slice.
  Operation *user = *loadOp.getResult().getUsers().begin();
  auto convertLayout = llvm::cast<ttg::ConvertLayoutOp>(user);
  auto newCvt = builder.create<ttg::ConvertLayoutOp>(
      convertLayout->getLoc(), convertLayout.getType(), extract.getResult());
  convertLayout->replaceAllUsesWith(newCvt->getResults());
  convertLayout->erase();
  loadOp.erase();

  // Fix up the yield op.
  appendToYield(forOp, {insertOp});
}
|
||||
|
||||
// Rewrite a tensor-pointer (TMA) load into an mbarrier-synchronized async
// copy: allocate an mbarrier per buffer outside the loop, then emit
// arrive -> insert_slice_async_v2 -> extract_slice -> mbarrier_wait at the
// load site. The load's single convert_layout user is re-pointed at the
// extracted slice, and the insert result is appended to the loop yield.
static void createTMALoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                          Value insertIdx, Value extractIdx, Value phase) {
  OpBuilder builder(forOp);
  Location loc = loadOp.getLoc();
  // Trivial single-CTA layout for the 1-D mbarrier array.
  auto CTALayout = ttg::CTALayoutAttr::get(loadOp.getContext(),
                                           /*CTAsPerCGA*/ {1},
                                           /*CTASplitNum*/ {1},
                                           /*CTAOrder*/ {0});
  auto sharedEncoding = ttg::SharedEncodingAttr::get(loadOp.getContext(), 1, 1,
                                                     1, {0}, CTALayout, false);
  // One mbarrier per buffer of the multi-buffered allocation.
  int64_t numBuffers = alloc.getType().cast<RankedTensorType>().getShape()[0];
  auto mBarriersTy = RankedTensorType::get(
      {numBuffers}, builder.getIntegerType(64), sharedEncoding);
  // Allocate an array of mbarrier objects outside the loop.
  Value barrierArray =
      builder.create<ttng::AllocMBarrierOp>(loc, mBarriersTy, 1);
  // Extract the barrier and emit the arrive/copy/wait/extract code sequence.
  builder.setInsertionPoint(loadOp);
  auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3);
  Value barrier = builder.create<ttng::ExtractMBarrierOp>(
      loc, mBarTy, barrierArray, insertIdx);
  // Only thread 0 performs the mbarrier arrive.
  Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
  Value threadId = builder.create<ttng::GetThreadIdOp>(loc);
  Value pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             threadId, zero);

  // Byte count the barrier tracks: elements of one CTA slice times the
  // element size in bytes.
  auto loadTy = loadOp.getType().dyn_cast<RankedTensorType>();
  auto loadShape = loadTy.getShape();
  auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding());
  auto shapePerSlice = ttg::getShapePerCTA(CTASplitNum, loadShape);
  auto elemTy = loadTy.getElementType();
  unsigned elems = std::accumulate(shapePerSlice.begin(), shapePerSlice.end(),
                                   1, std::multiplies{});
  elems *= (elemTy.getIntOrFloatBitWidth() / 8);
  builder.create<ttng::MBarrierArriveOp>(loc, barrier, pred,
                                         /*remoteCtaId*/ nullptr,
                                         /*trackAsyncOp*/ false, elems);
  auto allocType = alloc.getType().cast<RankedTensorType>();
  auto insertOp = builder.create<ttng::InsertSliceAsyncV2Op>(
      loc, allocType, loadOp.getPtr(), alloc,
      /*index*/ insertIdx, barrier, loadOp.getMask(), loadOp.getOther(),
      loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(),
      /*axis*/ 0);

  // Slice the buffer at `extractIdx` back out of the 3-D allocation.
  RankedTensorType sliceType = RankedTensorType::get(
      {allocType.getShape()[1], allocType.getShape()[2]},
      allocType.getElementType(), allocType.getEncoding());
  auto extract = builder.create<mlir::triton::gpu::ExtractSliceOp>(
      loc, sliceType, insertOp.getResult(),
      SmallVector<OpFoldResult>{extractIdx, int_attr(0), int_attr(0)},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(sliceType.getShape()[0]),
                                int_attr(sliceType.getShape()[1])},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});

  // Wait on the extract-side barrier with the loop-carried phase flag.
  Value barrierWait = builder.create<ttng::ExtractMBarrierOp>(
      loc, mBarTy, barrierArray, extractIdx);
  builder.create<ttng::MBarrierWaitOp>(loc, barrierWait, phase);

  // Re-point the load's single convert_layout user at the extracted slice.
  Operation *user = *loadOp.getResult().getUsers().begin();
  auto convertLayout = llvm::cast<ttg::ConvertLayoutOp>(user);
  auto newCvt = builder.create<ttg::ConvertLayoutOp>(
      convertLayout->getLoc(), convertLayout.getType(), extract.getResult());
  convertLayout->replaceAllUsesWith(newCvt->getResults());
  convertLayout->erase();
  loadOp.erase();

  // Fix up the yield op.
  appendToYield(forOp, {insertOp});
}
|
||||
|
||||
/// Create an async load equivalent to the given load.
|
||||
static void createAsyncLoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
|
||||
Value insertIdx, Value extractIdx, Value phase) {
|
||||
if (isLoadFromTensorPtr(loadOp)) {
|
||||
createTMALoad(forOp, loadOp, alloc, insertIdx, extractIdx, phase);
|
||||
} else {
|
||||
createAsyncCopy(forOp, loadOp, alloc, insertIdx, extractIdx);
|
||||
}
|
||||
}
|
||||
|
||||
// Return the transitive use of the load which is a dot operand.
|
||||
static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
|
||||
// We only pipeline loads that have one covert_layout (to dot_op) use
|
||||
// TODO: lift this constraint in the future
|
||||
bool isCandidate = false;
|
||||
if (!loadOp.getResult().hasOneUse())
|
||||
return Value();
|
||||
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
auto tensorType =
|
||||
convertLayout.getResult().getType().cast<RankedTensorType>();
|
||||
if (auto sharedEnc =
|
||||
tensorType.getEncoding().dyn_cast<ttg::SharedEncodingAttr>()) {
|
||||
if (sharedEnc.getHasLeadingOffset()) {
|
||||
// MMA V3 case.
|
||||
auto newOrder = sharedEnc.getOrder();
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
auto oldOrder = ttg::getOrder(ty.getEncoding());
|
||||
if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) {
|
||||
// The operand of MMAv3 is in SharedEncoding and it's order should
|
||||
// not be changed after FuseTranspositions Pass. So we only pipeline
|
||||
// the load if the order of the loaded BlockedEncoding is the same
|
||||
// as the order of the SharedEncoding it is converted to.
|
||||
// TODO: remove this constraint once the LoadOp supports transpose
|
||||
// fusion
|
||||
hasMMAV3 = true;
|
||||
return convertLayout.getResult();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Advance to the first conversion as long as the use resides in shared
|
||||
// memory and it has a single use itself
|
||||
while (use) {
|
||||
if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse())
|
||||
break;
|
||||
auto tensorType = use->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorType.getEncoding().isa<ttg::SharedEncodingAttr>())
|
||||
break;
|
||||
use = *use->getResult(0).getUsers().begin();
|
||||
}
|
||||
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType =
|
||||
convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto dotOpEnc = tensorType.getEncoding()
|
||||
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
|
||||
return convertLayout.getResult();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Value();
|
||||
}
|
||||
|
||||
namespace {
// A pipelinable load paired with the transitive dot-operand value it
// ultimately feeds (as found by loadDotOperand).
struct LoadDotOperand {
  LoadDotOperand(tt::LoadOp load, Value dotOperand)
      : load(load), dotOperand(dotOperand) {}
  // The load op to be converted into an async copy.
  tt::LoadOp load;
  // The convert_layout result with a dot-operand (or MMAv3 shared) encoding.
  Value dotOperand;
};
} // namespace
|
||||
|
||||
/// Collect loads to pipeline into `ops`. A load qualifies if it is either a
/// tensor-pointer (TMA) load, or a regular load of a rank>=2 tensor whose
/// contiguity/alignment allows a vector width of at least 32 bits, and if it
/// transitively feeds a dot operand. Sets `hasMMAV3` via loadDotOperand.
static void collectOpsToPipeline(scf::ForOp forOp,
                                 SmallVectorImpl<LoadDotOperand> &ops,
                                 bool &hasMMAV3) {
  ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
  ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);

  // We cannot use forOp.walk(...) here because we only want to visit the
  // operations in the loop body block. Nested blocks are handled separately.
  for (Operation &op : forOp) {
    if (auto loadOp = dyn_cast<tt::LoadOp>(&op)) {
      bool candidate = false;
      if (isLoadFromTensorPtr(loadOp)) {
        // Map to TMA load.
        candidate = true;
      } else {
        auto ptr = loadOp.getPtr();
        // Effective vector width is bounded by both pointer contiguity and
        // mask alignment.
        unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
        if (auto mask = loadOp.getMask())
          vec =
              std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));

        auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
        if (!tensorTy || tensorTy.getRank() < 2)
          continue;
        auto ty =
            tensorTy.getElementType().cast<tt::PointerType>().getPointeeType();
        unsigned width = vec * ty.getIntOrFloatBitWidth();
        // We do not pipeline all loads for the following reasons:
        // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
        // 2. It's likely that pipelining small loads won't offer much
        //    performance improvement and may even hurt performance by
        //    increasing register pressure.
        if (width >= 32)
          candidate = true;
      }
      if (!candidate)
        continue;
      // Only keep loads that actually feed a dot operand.
      Value dotOperand = loadDotOperand(loadOp, hasMMAV3);
      if (!dotOperand)
        continue;
      ops.emplace_back(loadOp, dotOperand);
    }
  }
}
|
||||
|
||||
// Create an allocation that can hold `distance` copies of the loadOp shape
// (one buffer per in-flight pipeline stage). The shared encoding is derived
// from the dot-operand encoding the load feeds, or from the element type for
// the MMAv3 (shared-encoding operand) case. Returns a null Value if the load
// has more than one use.
static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, Value dotOperand,
                         unsigned distance) {
  OpBuilder builder(forOp);
  auto ty = loadOp.getType().cast<RankedTensorType>();
  if (!loadOp.getResult().hasOneUse())
    return Value();
  Attribute sharedEnc;
  auto CTALayout = ttg::getCTALayout(ty.getEncoding());
  auto tensorType = dotOperand.getType().cast<RankedTensorType>();
  if (auto dotOpEnc =
          tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>()) {
    // A transpose feeding the convert means the shared layout must be
    // built for the transposed access pattern.
    auto convertLayout = dotOperand.getDefiningOp<ttg::ConvertLayoutOp>();
    bool needTrans = dyn_cast_or_null<tt::TransOp>(
        convertLayout->getOperand(0).getDefiningOp());
    unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth();
    sharedEnc = ttg::SharedEncodingAttr::get(
        ty.getContext(), dotOpEnc, ty.getShape(),
        ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans);
  } else {
    // MMAv3: operand stays in shared encoding.
    sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(),
                                             ttg::getOrder(ty.getEncoding()),
                                             CTALayout, ty.getElementType());
  }
  // Prepend the buffer-count dimension to the load shape.
  SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
  bufferShape.insert(bufferShape.begin(), distance);
  Type allocType =
      RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
  Value alloc = builder.create<mlir::triton::gpu::AllocTensorOp>(
      loadOp.getLoc(), allocType);
  return alloc;
}
|
||||
|
||||
// Convert load ops into their asyn version and apply multi-buffering based on
|
||||
// the number of stages.
|
||||
static void createAsynOps(scf::ForOp &forOp, ArrayRef<LoadDotOperand> loads,
|
||||
int numStages, bool hasMMAV3) {
|
||||
struct AsyncLoad {
|
||||
AsyncLoad(tt::LoadOp loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {}
|
||||
tt::LoadOp loadOp;
|
||||
Value alloc;
|
||||
};
|
||||
int numBuffers = numStages - 1;
|
||||
// For MMAv3 we need an extra buffer as this is assumed in the wgmma
|
||||
// pipelining post-processing.
|
||||
// TODO: Improve modeling of wgmma pipelining.
|
||||
if (hasMMAV3)
|
||||
numBuffers++;
|
||||
SmallVector<AsyncLoad> asyncLoads;
|
||||
SmallVector<Value> newOperands;
|
||||
bool needsMbarrierPhase = false;
|
||||
bool needsAsyncWait = false;
|
||||
for (const LoadDotOperand &loadOperand : loads) {
|
||||
tt::LoadOp loadOp = loadOperand.load;
|
||||
Value dotOperand = loadOperand.dotOperand;
|
||||
Value alloc = createAlloc(forOp, loadOp, dotOperand, numBuffers);
|
||||
assert(alloc && "Failed to create alloc for the async load.");
|
||||
newOperands.push_back(alloc);
|
||||
asyncLoads.emplace_back(loadOp, alloc);
|
||||
if (isLoadFromTensorPtr(loadOp))
|
||||
needsMbarrierPhase = true;
|
||||
else
|
||||
needsAsyncWait = true;
|
||||
}
|
||||
|
||||
OpBuilder builder(forOp);
|
||||
Location loc = forOp.getLoc();
|
||||
// Create two new counters to index into the allocs.
|
||||
Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32);
|
||||
Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value insertIdx = minusOne;
|
||||
Value extractIdx = minusOne;
|
||||
Value numBuffersVal =
|
||||
builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
|
||||
newOperands.push_back(insertIdx);
|
||||
newOperands.push_back(extractIdx);
|
||||
Value phase;
|
||||
if (needsMbarrierPhase) {
|
||||
phase = builder.create<arith::ConstantIntOp>(loc, 0, 1);
|
||||
newOperands.push_back(phase);
|
||||
}
|
||||
unsigned newOperandIndex = forOp.getBody()->getNumArguments();
|
||||
// Patch the loop to add the new loop carried dependencies.
|
||||
scf::ForOp newForOp =
|
||||
replaceForOpWithNewSignature(builder, forOp, newOperands);
|
||||
forOp.erase();
|
||||
forOp = newForOp;
|
||||
for (int i = 0; i < asyncLoads.size(); i++) {
|
||||
asyncLoads[i].alloc = newForOp.getBody()->getArgument(newOperandIndex + i);
|
||||
}
|
||||
insertIdx =
|
||||
newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size());
|
||||
extractIdx =
|
||||
newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size() + 1);
|
||||
|
||||
// Create two counters for the insert and extract indices to avoid creating
|
||||
// long liverange.
|
||||
builder.setInsertionPoint(asyncLoads.front().loadOp);
|
||||
insertIdx = builder.create<arith::AddIOp>(loc, insertIdx, one);
|
||||
Value cndIns = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
insertIdx, numBuffersVal);
|
||||
insertIdx = builder.create<arith::SelectOp>(loc, cndIns, insertIdx, zero);
|
||||
|
||||
extractIdx = builder.create<arith::AddIOp>(loc, extractIdx, one);
|
||||
Value cndExt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
extractIdx, numBuffersVal);
|
||||
extractIdx = builder.create<arith::SelectOp>(loc, cndExt, extractIdx, zero);
|
||||
|
||||
if (needsMbarrierPhase) {
|
||||
phase = newForOp.getBody()->getArgument(newOperandIndex +
|
||||
asyncLoads.size() + 2);
|
||||
Value oneI1 = builder.create<arith::ConstantIntOp>(loc, 1, 1);
|
||||
Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, oneI1);
|
||||
phase = builder.create<arith::SelectOp>(loc, cndExt, phase, nextPhase);
|
||||
}
|
||||
|
||||
bool firstLoad = true;
|
||||
for (AsyncLoad &asyncLoad : asyncLoads) {
|
||||
createAsyncLoad(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx,
|
||||
extractIdx, phase);
|
||||
firstLoad = false;
|
||||
}
|
||||
// Insert a waitOp after the first async copy. This does make the assumption
|
||||
// that the wait will be scheduled in a different stage that all the async
|
||||
// copy but we cannot guarantee that one wait is enough otherwise.
|
||||
for (auto &op : forOp.getBody()->without_terminator()) {
|
||||
if (isa<ttg::InsertSliceAsyncOp>(op)) {
|
||||
OpBuilder builder(op.getContext());
|
||||
builder.setInsertionPointAfter(&op);
|
||||
builder.create<ttg::AsyncWaitOp>(op.getLoc(), 0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
|
||||
if (needsMbarrierPhase)
|
||||
newYieldOperands.push_back(phase);
|
||||
// Patch the yield with the updated counters.
|
||||
appendToYield(forOp, newYieldOperands);
|
||||
}
|
||||
|
||||
// Combine the current mask with the given predicate.
|
||||
static Value getPredMask(RewriterBase &rewriter, Type typeLike,
|
||||
Value currentMask, Value pred) {
|
||||
Type maskType = tt::getI1SameShape(typeLike);
|
||||
Location loc = pred.getLoc();
|
||||
Value mask = pred;
|
||||
if (maskType.isa<RankedTensorType>()) {
|
||||
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
|
||||
}
|
||||
if (currentMask) {
|
||||
mask = rewriter.create<arith::AndIOp>(loc, mask, currentMask);
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
// Function to mask operations during scheduling: the pipeline expander calls
// this for ops hoisted into the prologue/epilogue so they only execute when
// `pred` holds. Side-effect-free ops and the pure synchronization ops need no
// predication; memory ops get `pred` folded into their existing mask.
static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
                              Value pred) {
  OpBuilder::InsertionGuard guard(rewriter);
  // No memory effects: safe to execute unconditionally.
  if (mlir::isMemoryEffectFree(op))
    return op;
  if (isa<ttg::AsyncCommitGroupOp>(op))
    return op;
  if (isa<ttg::AsyncWaitOp>(op))
    return op;
  if (auto insertOp = dyn_cast<ttg::InsertSliceAsyncOp>(op)) {
    rewriter.setInsertionPoint(insertOp);
    Value mask = getPredMask(rewriter, insertOp.getSrc().getType(),
                             insertOp.getMask(), pred);
    insertOp.getMaskMutable().assign(mask);
    return op;
  }
  if (auto insertOp = dyn_cast<ttng::InsertSliceAsyncV2Op>(op)) {
    rewriter.setInsertionPoint(insertOp);
    // The src is a tensor pointer; the mask shape follows the pointee type.
    Value mask = getPredMask(
        rewriter,
        insertOp.getSrc().getType().cast<tt::PointerType>().getPointeeType(),
        insertOp.getMask(), pred);
    insertOp.getMaskMutable().assign(mask);
    return op;
  }
  if (auto arriveOp = dyn_cast<ttng::MBarrierArriveOp>(op)) {
    rewriter.setInsertionPoint(arriveOp);
    // The arrive predicate is a scalar i1.
    Value mask = getPredMask(rewriter, rewriter.getIntegerType(1),
                             arriveOp.getPred(), pred);
    arriveOp.getPredMutable().assign(mask);
    return op;
  }
  if (isa<ttng::MBarrierWaitOp>(op)) {
    return op;
  }
  if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
    rewriter.setInsertionPoint(loadOp);
    Value mask = getPredMask(rewriter, loadOp.getPtr().getType(),
                             loadOp.getMask(), pred);
    loadOp.getMaskMutable().assign(mask);
    return op;
  }

  // Any op kind not handled above is a bug in the scheduler.
  assert("don't know how to predicate this op" && false);
  return op;
}
|
||||
|
||||
// Annotation hook for the pipeline expander: set the pending-group count on
// async_wait ops to the number of loads in flight per stage. `part` and
// `iteration` are part of the callback signature but unused here.
static void setWaitNum(Operation *op,
                       mlir::triton::PipeliningOption::PipelinerPart part,
                       unsigned iteration, unsigned numLoadsInStage) {
  auto waitOp = dyn_cast<ttg::AsyncWaitOp>(op);
  if (!waitOp)
    return;
  waitOp.setNum(numLoadsInStage);
}
|
||||
|
||||
/// Helper to recursively add dependencies to the same stage: inserts `op` and
/// every transitive same-block producer of its operands into `deps`. When
/// `includeArg` is set, block arguments are chased through the loop's yield
/// (distance-1 dependencies); `filter` ops and their subtrees are excluded.
static void addDep(Operation *op, DenseSet<Operation *> &deps,
                   bool includeArg = true,
                   DenseSet<Operation *> *filter = nullptr) {
  if (filter && filter->count(op))
    return;
  // Already visited: stop to avoid exponential re-walks and cycles.
  if (!deps.insert(op).second)
    return;
  for (Value operand : op->getOperands()) {
    Value v = operand;
    // Guards against cycles through the loop-carried values.
    llvm::SmallDenseSet<Value> seen;
    while (auto arg = v.dyn_cast<BlockArgument>()) {
      if (!includeArg)
        break;
      if (!seen.insert(v).second)
        break;
      // Iter args (argument 0 is the induction variable) map back to the
      // corresponding yield operand of the previous iteration.
      if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) {
        auto yieldOp = op->getBlock()->getTerminator();
        v = yieldOp->getOperand(arg.getArgNumber() - 1);
        continue;
      }
      break;
    }
    Operation *defOp = v.getDefiningOp();
    // Only recurse into producers that live in the same block (loop body).
    if (defOp && defOp->getBlock() == op->getBlock()) {
      addDep(defOp, deps, includeArg, filter);
    }
  }
}
|
||||
|
||||
// Add operations to the shedule with the given stage based on the filter
|
||||
// function.
|
||||
static void addOps(scf::ForOp forOp, int stage,
|
||||
std::vector<std::pair<Operation *, unsigned>> &schedule,
|
||||
std::function<bool(Operation *)> filter) {
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
if (!filter(&op))
|
||||
continue;
|
||||
schedule.emplace_back(&op, stage);
|
||||
}
|
||||
}
|
||||
|
||||
// Create the schedule for a matmul loop. This is ad hoc based on how we know
// matmul loops should be pipelined and is not a generic scheduler.
// Stage assignment: inserts (async copies) and their deps -> stage 0,
// distance-1 non-load deps -> stage 1, extracts (when prefetched) ->
// stage numStages-2, everything else -> stage numStages-1.
static std::vector<std::pair<Operation *, unsigned>>
createSchedule(scf::ForOp forOp, int numStages, bool prefetchExtract) {
  SmallVector<Operation *> insertOps;
  SmallVector<Operation *> extractOps;
  // Find the insert/extract ops that will go respectively in stage 0 and stage
  // `numStages - 2`. All the other operations will go in stage
  // `numStages - 1`.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (isa<ttg::InsertSliceAsyncOp, ttg::AsyncCommitGroupOp,
            ttng::MBarrierArriveOp, ttng::InsertSliceAsyncV2Op>(op))
      insertOps.emplace_back(&op);
    if (prefetchExtract) {
      if (isa<ttg::ExtractSliceOp, ttg::AsyncWaitOp>(op))
        extractOps.emplace_back(&op);
    }
  }
  // Inserts and their same-iteration producers (block args not chased).
  DenseSet<Operation *> insertAndDeps;
  for (Operation *op : insertOps) {
    addDep(op, insertAndDeps, false);
  }

  // Find dependencies with distance of 1: producers of the previous
  // iteration's values that feed this stage through the yield.
  SmallVector<Operation *> distanceOneUsers;
  for (Operation *op : insertAndDeps) {
    for (Value operand : op->getOperands()) {
      if (auto arg = operand.dyn_cast<BlockArgument>()) {
        if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) {
          auto yieldOp = op->getBlock()->getTerminator();
          Value v = yieldOp->getOperand(arg.getArgNumber() - 1);
          Operation *defOp = v.getDefiningOp();
          if (defOp && insertAndDeps.count(defOp) == 0) {
            distanceOneUsers.push_back(defOp);
          }
        }
      }
    }
  }
  // Schedule loads with a distance of 1 in stage 0
  for (Operation *op : distanceOneUsers) {
    if (isa<tt::LoadOp>(op)) {
      addDep(op, insertAndDeps, true);
    }
  }
  // For the rest of the ops we can move them into stage 1 so that they can be
  // closer to their uses.
  DenseSet<Operation *> stage1deps;
  for (Operation *op : distanceOneUsers) {
    if (!isa<tt::LoadOp>(op)) {
      addDep(op, stage1deps, true, &insertAndDeps);
    }
  }

  // Extracts and their producers, excluding anything already in stage 0.
  DenseSet<Operation *> extractAndDeps;
  for (Operation *op : extractOps) {
    addDep(op, extractAndDeps, true, &insertAndDeps);
  }
  std::vector<std::pair<Operation *, unsigned>> schedule;
  // Schedule stage `numStage - 1` first.
  addOps(forOp, numStages - 1, schedule, [&](Operation *op) {
    return insertAndDeps.count(op) == 0 && stage1deps.count(op) == 0 &&
           extractAndDeps.count(op) == 0;
  });

  // Schedule some dependencies with distance of 1 into stage 1 to reduce
  // pressure.
  addOps(forOp, 1, schedule,
         [&](Operation *op) { return stage1deps.count(op); });

  // Then Schedule stage 0.
  addOps(forOp, 0, schedule,
         [&](Operation *op) { return insertAndDeps.count(op); });

  // Finally schedule the extract ops in stage `numStage - 2` so that they get
  // pre-fetched and play well with the prefetch pass.
  addOps(forOp, numStages - 2, schedule,
         [&](Operation *op) { return extractAndDeps.count(op); });
  return schedule;
}
|
||||
|
||||
// Entry point of the pre-pipelining phase: collect pipelinable loads, rewrite
// them into async copies with multi-buffering, build the stage schedule, and
// fill in the PipeliningOption callbacks for the expander. Returns false when
// the loop has nothing to pipeline (and leaves it untouched).
bool mlir::triton::preProcessLoopAndGetSchedule(
    scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
  // 1. First collect "interesting" operations with a stage where to schedule
  // them. This gives a coarse scheduling for the loop.
  SmallVector<LoadDotOperand> loads;
  bool hasMMAV3 = false;
  collectOpsToPipeline(forOp, loads, hasMMAV3);
  if (loads.empty())
    return false;
  // True if at least one load uses cp.async (i.e. is not a TMA load).
  bool hasAsynCp = llvm::any_of(loads, [](LoadDotOperand &load) {
    return !isLoadFromTensorPtr(load.load);
  });
  // 2. Convert the loads into async loads and create the allocs.
  createAsynOps(forOp, loads, numStages, hasMMAV3);

  // 3. Create the final schedule for the kernel loop. This will dictate the
  // stages and order of operations to the pipeline expander.
  std::vector<std::pair<Operation *, unsigned>> schedule =
      createSchedule(forOp, numStages, /*prefetchExtract=*/!hasMMAV3);

  // 4. Fill out the pipeline options.
  // NOTE(review): `schedule` is captured by copy into a non-mutable lambda,
  // so the std::move below actually copies — harmless, but worth confirming
  // the callback is only meant to run once.
  options.getScheduleFn =
      [schedule](scf::ForOp forOp,
                 std::vector<std::pair<Operation *, unsigned>> &s) {
        s = std::move(schedule);
      };
  options.peelEpilogue = false;
  options.predicateFn = predicateOp;
  options.supportDynamicLoops = true;
  // Each of `numStages - 2` overlapped stages keeps all loads in flight.
  unsigned numLoadsInStage = (numStages - 2) * loads.size();
  options.annotateFn =
      [numLoadsInStage](Operation *op,
                        mlir::triton::PipeliningOption::PipelinerPart part,
                        unsigned iteration) {
        return setWaitNum(op, part, iteration, numLoadsInStage);
      };

  if (hasAsynCp) {
    // Insert a wait 0 after the loop to drain all outstanding async copies.
    OpBuilder builder(forOp);
    builder.setInsertionPointAfter(forOp);
    builder.create<ttg::AsyncWaitOp>(forOp.getLoc(), 0);
  }
  return true;
}
|
||||
|
||||
/// MMA V3 post-processing.
/// Return true if `dotOp`'s result is loop-carried (yielded) and one of its
/// operands transitively depends on that same iter arg, i.e. the dot feeds
/// itself across iterations. On success `*firstUse` is set to the first user
/// of the corresponding region iter arg.
static bool selfDepend(tt::DotOp dotOp, scf::ForOp forOp,
                       Operation **firstUse) {
  // Recursively check whether `v` (or any producer of it) is the loop's
  // iter arg number `argId`.
  std::function<bool(Value, int, scf::ForOp)> dependOn =
      [&dependOn](Value v, int argId, scf::ForOp forOp) {
        auto op = v.getDefiningOp();
        if (isa<BlockArgument>(v)) {
          auto iterArgs = forOp.getRegionIterArgs();
          auto iter = std::find(iterArgs.begin(), iterArgs.end(), v);
          if (iter != iterArgs.end())
            return std::distance(iterArgs.begin(), iter) == argId;
        } else {
          if (!op)
            return false;
          for (auto operand : op->getOperands()) {
            if (dependOn(operand, argId, forOp))
              return true;
          }
        }
        return false;
      };
  // Find which yield operand (if any) carries the dot result.
  auto result = dotOp.getResult();
  auto yieldOp = forOp.getBody()->getTerminator();
  int argIdx = -1;
  auto iter = std::find(yieldOp->getOperands().begin(),
                        yieldOp->getOperands().end(), result);
  if (iter != yieldOp->getOperands().end())
    argIdx = std::distance(yieldOp->getOperands().begin(), iter);
  if (argIdx == -1)
    return false;
  // Check whether any dot operand reaches back to that iter arg.
  for (auto operand : dotOp.getOperands()) {
    if (dependOn(operand, argIdx, forOp)) {
      auto iterArgs = forOp.getRegionIterArgs();
      *firstUse = iterArgs[argIdx].use_begin().getUser();
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Drop a redundant dot_wait: when the loop already contains a dot_wait with
// zero pendings, this extra wait op is unnecessary and is erased.
static void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp,
                            bool hasDotWait0) {
  if (!hasDotWait0)
    return;
  dotWaitOp->erase();
}
|
||||
|
||||
/// Convert eligible Hopper-MMA tt::DotOp ops in `forOp` into asynchronous
/// tt::nvidia_gpu::DotAsyncOp ops, inserting dot_wait synchronization both
/// inside the loop (to pipeline wgmma by one stage) and after the loop (to
/// wait on outstanding async dots before their results are consumed).
void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
  Block *loop = forOp.getBody();
  // Returns the region-block index (within `forOp`'s region) that directly
  // contains `op`, walking up through nested ops; -1 for a null op.
  auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) {
    if (!op)
      return -1l;
    auto lastOp = op;
    while (op->getBlock()->getParentOp() != forOp) {
      lastOp = op;
      op = op->getBlock()->getParentOp();
    }
    return std::distance(lastOp->getBlock()->getParent()->begin(),
                         lastOp->getBlock()->getIterator());
  };
  /// XXX(Keren): Clean up the following duplicate code with checkDotOp
  /// dots to be pipelined
  bool hasSyncDot = false;
  bool hasDotWait0 = false;
  SmallVector<tt::DotOp> allDots;
  SmallVector<tt::DotOp> dots;
  SmallVector<unsigned> resultNeedSync;
  // First pass: record every dot in the loop and whether a wait with
  // `pendings == 0` already exists.
  for (Operation &op : *loop) {
    if (auto dotWaitOp = dyn_cast<tt::nvidia_gpu::DotWaitOp>(&op)) {
      auto attr = dotWaitOp->getAttrOfType<IntegerAttr>("pendings");
      auto pendingCount = attr.getInt();
      if (pendingCount == 0)
        hasDotWait0 = true;
    }
    if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
      allDots.push_back(dotOp);
    }
  }
  // Second pass: select the dots that are safe to make asynchronous.
  for (Operation &op : *loop) {
    if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
      auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
      if (auto resEnc = resTy.getEncoding().dyn_cast<ttg::MmaEncodingAttr>()) {
        if (resEnc && resEnc.isHopper()) {
          auto dot = dotOp.getResult();
          bool valid = true;

          // all users of dot should be scf.yield
          if (!dot.hasOneUse())
            valid = false;
          if (!isa<scf::YieldOp>(*dot.getUsers().begin()))
            valid = false;

          Operation *firstUse = nullptr;
          auto depend = selfDepend(dotOp, forOp, &firstUse);
          bool selfDirectDepend = (dotOp == firstUse);
          // A dot that executes before this dot's first use forces extra
          // synchronization (hasSyncDot).
          for (auto tempInAll : allDots) {
            auto iter = std::find(dots.begin(), dots.end(), tempInAll);
            if (iter != dots.end())
              continue;
            auto db = getBlockNumInFor(tempInAll, forOp);
            auto fb = getBlockNumInFor(firstUse, forOp);
            if (db < fb ||
                (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse)))
              hasSyncDot = true;
          }
          auto CArg = dotOp.getOperand(2);
          if (!(selfDirectDepend ||
                (depend && !selfDirectDepend && hasSyncDot)) ||
              !CArg.hasOneUse())
            valid = false;

          if (valid) {
            dots.push_back(dotOp);
            // Remember which yield operand (and hence loop result) carries
            // this dot's value, so it can be synchronized after the loop.
            resultNeedSync.push_back(
                dotOp->getUses().begin()->getOperandNumber());
          }
        }
      }
    }
  }

  // Early stop: no need to continue if there is no valid dot in the loop.
  if (dots.empty())
    return;

  OpBuilder builder(forOp);
  // 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
  // wgmma ops by one stage.
  // This is needed to prevent shared memory inputs to be overridden before the
  // operation is completed.
  // TODO: merge this with the rest of the pipelining transformation and look at
  // a better representation for async dots.
  tt::DotOp lastDot = dots.back();
  builder.setInsertionPointAfter(lastDot);
  auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
      lastDot.getLoc(), lastDot.getResult(), dots.size());

  // 1. replace Dot with DotAsync
  for (size_t idx = 0; idx < dots.size(); ++idx) {
    tt::DotOp dotOp = dots[idx];
    builder.setInsertionPoint(dotOp);
    auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
        dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
        dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
    dotOp.replaceAllUsesWith(dotAsync.getResult());
    dotOp->erase();
  }

  hasDotWait0 = hasDotWait0 || hasSyncDot;

  // 2. If there's any outstanding DotAsyncOps, we need to wait for them.
  builder.setInsertionPointAfter(forOp);
  SmallVector<Value> waitOperands;
  for (size_t i = 0; i < resultNeedSync.size(); ++i) {
    Value result = forOp->getResult(resultNeedSync[i]);
    if (result.use_empty())
      continue;
    waitOperands.push_back(result);
  }
  if (!waitOperands.empty()) {
    auto afterLoopWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
        forOp.getLoc(), waitOperands, 0);
    // Bug fix: the wait op only has one result per *used* loop result, so the
    // wait-result index must be tracked separately from the resultNeedSync
    // index; indexing with the unfiltered `i` reads out of range whenever a
    // dead result precedes a live one.
    unsigned waitIdx = 0;
    for (size_t i = 0; i < resultNeedSync.size(); ++i) {
      Value result = forOp->getResult(resultNeedSync[i]);
      if (result.use_empty())
        continue;
      result.replaceAllUsesExcept(afterLoopWait.getResult(waitIdx++),
                                  afterLoopWait);
    }
  }

  // 3. potentially remove redundant dot_wait after dot_async if having multiple
  // DotOp in the loop
  removeExtraWait(dotWait, hasDotWait0);
}
|
||||
704
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
Normal file
704
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
Normal file
@@ -0,0 +1,704 @@
|
||||
//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements loop software pipelining
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Fork of upstream pipeliner. This will be merged upstream once things are
|
||||
// stable. Modifications so far are:
|
||||
// -Bug fix for def with a distance of 1 scheduled in stage 0.
|
||||
// -Support dynamic loops and predicate operations in the prologue.
|
||||
// -Support for non-index type for induction variable.
|
||||
// -Support source with distance of 1 used multiple stages later.
|
||||
// -Fix bug when a value yield is used outside the loop and the value def is not
|
||||
// in the last stage. If we are not peeling the epilgue we need to remap the
|
||||
// output correctly.
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "PipelineExpander.h"
|
||||
|
||||
#define DEBUG_TYPE "triton-loop-pipelining"
|
||||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
|
||||
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::scf;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Helper to keep internal information during pipelining transformation.
struct LoopPipelinerInternal {
  /// Coarse liverange information for ops used across stages.
  struct LiverangeInfo {
    // Last pipeline stage in which the value is read.
    unsigned lastUseStage = 0;
    // Pipeline stage in which the value is defined.
    unsigned defStage = 0;
  };

protected:
  // Loop being pipelined.
  ForOp forOp;
  // Highest stage number appearing in the schedule (number of stages - 1).
  unsigned maxStage = 0;
  // Stage assigned to each operation of the loop body.
  DenseMap<Operation *, unsigned> stages;
  // Operations of the loop body in the order chosen by the schedule.
  std::vector<Operation *> opOrder;
  // Loop bounds and step of `forOp`, cached by initializeLoopInfo.
  Value ub;
  Value lb;
  Value step;
  // True when the trip count is not provably larger than the stage count.
  bool dynamicLoop;
  // Optional callback used to annotate generated ops with their pipeline part.
  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
  // Whether the epilogue is peeled out of the loop (vs. predicated in-loop).
  bool peelEpilogue;
  // Callback used to predicate an op when its iteration may not execute.
  triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;

  // When peeling the kernel we generate several version of each value for
  // different stage of the prologue. This map tracks the mapping between
  // original Values in the loop and the different versions
  // peeled from the loop.
  DenseMap<Value, llvm::SmallVector<Value>> valueMapping;

  /// Assign a value to `valueMapping`, this means `el` represents the version
  /// `idx` of `key` in the epilogue.
  void setValueMapping(Value key, Value el, int64_t idx);

  /// Return the operation defining `value` and the loop-carried distance
  /// (0 for a direct def, 1 when reached through an iter-arg).
  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);

public:
  /// Initialize the information for the given `op`, return true if it
  /// satisfies the pre-condition to apply pipelining.
  bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options);
  /// Emits the prologue, this creates `maxStage - 1` part which will contain
  /// operations from stages [0; i], where i is the part index.
  void emitPrologue(RewriterBase &rewriter);
  /// Gather liverange information for Values that are used in a different stage
  /// than its definition.
  llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
  /// Create the new kernel scf.for, threading cross-stage values through new
  /// iter-args; `loopArgMap` receives the (value, version) -> arg-index map.
  scf::ForOp createKernelLoop(
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      RewriterBase &rewriter,
      llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
  /// Emits the pipelined kernel. This clones loop operations following user
  /// order and remaps operands defined in a different stage as their use.
  LogicalResult createKernel(
      scf::ForOp newForOp,
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
      RewriterBase &rewriter);
  /// Emits the epilogue, this creates `maxStage - 1` part which will contain
  /// operations from stages [i; maxStage], where i is the part index.
  llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
};
|
||||
|
||||
/// Validate `op` against the pipelining pre-conditions and cache the loop
/// bounds, schedule, and per-op stages. Returns false (after logging) on any
/// unsupported shape: dynamic bounds without support, short trip counts,
/// missing/illegal stage assignments, or loop-carried distance > 1.
bool LoopPipelinerInternal::initializeLoopInfo(
    ForOp op, const triton::PipeliningOption &options) {
  LDBG("Start initializeLoopInfo");
  forOp = op;
  ub = forOp.getUpperBound();
  lb = forOp.getLowerBound();
  step = forOp.getStep();

  dynamicLoop = true;
  auto upperBoundCst = ub.getDefiningOp<arith::ConstantIndexOp>();
  auto lowerBoundCst = lb.getDefiningOp<arith::ConstantIndexOp>();
  auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>();
  if (!upperBoundCst || !lowerBoundCst || !stepCst) {
    if (!options.supportDynamicLoops) {
      LDBG("--dynamic loop not supported -> BAIL");
      return false;
    }
  } else {
    int64_t ubImm = upperBoundCst.value();
    int64_t lbImm = lowerBoundCst.value();
    int64_t stepImm = stepCst.value();
    int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm);
    // NOTE(review): `maxStage` is still 0 here (it is computed from the
    // schedule further below), so this compares the trip count against 0 —
    // confirm whether the check was meant to run after maxStage is known.
    if (numIteration > maxStage) {
      dynamicLoop = false;
    } else if (!options.supportDynamicLoops) {
      LDBG("--fewer loop iterations than pipeline stages -> BAIL");
      return false;
    }
  }
  peelEpilogue = options.peelEpilogue;
  predicateFn = options.predicateFn;
  // Without a peeled epilogue (or with dynamic bounds) ops must be
  // predicated, so a predicate callback is mandatory.
  if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
    LDBG("--no epilogue or predicate set -> BAIL");
    return false;
  }
  std::vector<std::pair<Operation *, unsigned>> schedule;
  options.getScheduleFn(forOp, schedule);
  if (schedule.empty()) {
    LDBG("--empty schedule -> BAIL");
    return false;
  }

  opOrder.reserve(schedule.size());
  for (auto &opSchedule : schedule) {
    maxStage = std::max(maxStage, opSchedule.second);
    stages[opSchedule.first] = opSchedule.second;
    opOrder.push_back(opSchedule.first);
  }

  // All operations need to have a stage.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (!stages.contains(&op)) {
      op.emitOpError("not assigned a pipeline stage");
      LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
      return false;
    }
  }

  // Currently, we do not support assigning stages to ops in nested regions. The
  // block of all operations assigned a stage should be the single `scf.for`
  // body block.
  for (const auto &[op, stageNum] : stages) {
    (void)stageNum;
    if (op == forOp.getBody()->getTerminator()) {
      op->emitError("terminator should not be assigned a stage");
      LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
      return false;
    }
    if (op->getBlock() != forOp.getBody()) {
      op->emitOpError("the owning Block of all operations assigned a stage "
                      "should be the loop body block");
      LDBG("--the owning Block of all operations assigned a stage "
           "should be the loop body block: "
           << *op << " -> BAIL");
      return false;
    }
  }

  // Only support loop carried dependency with a distance of 1. This means the
  // source of all the scf.yield operands needs to be defined by operations in
  // the loop.
  if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
                   [this](Value operand) {
                     Operation *def = operand.getDefiningOp();
                     return !def || !stages.contains(def);
                   })) {
    LDBG("--only support loop carried dependency with a distance of 1 -> BAIL");
    return false;
  }
  annotateFn = options.annotateFn;
  return true;
}
|
||||
|
||||
/// Clone `op` and call `callback` on the cloned op's operands as well as any
/// operands of nested ops that:
/// 1) aren't defined within the new op or
/// 2) are block arguments.
/// The callback may rewrite each such operand in place (e.g. to remap it to a
/// different version of the value); operands defined inside the clone itself
/// are left alone.
static Operation *
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
                       function_ref<void(OpOperand *newOperand)> callback) {
  Operation *clone = rewriter.clone(*op);
  // Top-level operands of the clone always escape the clone, so visit all.
  for (OpOperand &operand : clone->getOpOperands())
    callback(&operand);
  // For nested ops, only visit operands whose def is outside the clone (or
  // which are block arguments).
  clone->walk([&](Operation *nested) {
    for (OpOperand &operand : nested->getOpOperands()) {
      Operation *def = operand.get().getDefiningOp();
      if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
        callback(&operand);
    }
  });
  return clone;
}
|
||||
|
||||
/// Emit the prologue: for each stage i in [0, maxStage) clone the ops of
/// stages [0, i], remapping operands to the version produced by the matching
/// prologue part, and record every produced value version in `valueMapping`.
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
  // Initialize the iteration argument to the loop initial values.
  for (BlockArgument &arg : forOp.getRegionIterArgs()) {
    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
    setValueMapping(arg, operand.get(), 0);
  }
  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  Location loc = forOp.getLoc();
  for (int64_t i = 0; i < maxStage; i++) {
    Value predicate;
    if (dynamicLoop) {
      Type t = ub.getType();
      // pred = ub > lb + (i * step)
      Value iv = rewriter.create<arith::AddIOp>(
          loc, lb,
          rewriter.create<arith::MulIOp>(
              loc, step,
              rewriter.create<arith::ConstantOp>(
                  loc, rewriter.getIntegerAttr(t, i))));
      predicate = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 iv, ub);
    }

    // special handling for induction variable as the increment is implicit.
    // iv = lb + i * step
    Type t = lb.getType();
    Value iv = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(loc,
                                               rewriter.getIntegerAttr(t, i))));
    setValueMapping(forOp.getInductionVar(), iv, i);
    for (Operation *op : opOrder) {
      // Part i only contains ops of stage <= i.
      if (stages[op] > i)
        continue;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              // Use the version produced `stages[op]` parts earlier.
              Value replacement = it->second[i - stages[op]];
              newOperand->set(replacement);
            }
          });
      if (predicate) {
        newOp = predicateFn(rewriter, newOp, predicate);
        assert(newOp && "failed to predicate op.");
      }
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        i - stages[op]);
        // If the value is a loop carried dependency update the loop argument
        // mapping.
        for (OpOperand &operand : yield->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          newOp->getResult(destId), i - stages[op] + 1);
        }
      }
    }
  }
}
|
||||
|
||||
/// Resolve `value` to its defining op, looking through one level of
/// loop-carried iter-arg indirection. Returns {def, distance} where distance
/// is 1 when the value was reached through an iter-arg (previous iteration)
/// and 0 otherwise; {nullptr, 0} when there is no in-loop defining op.
std::pair<Operation *, int64_t>
LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
  int64_t distance = 0;
  if (auto arg = dyn_cast<BlockArgument>(value)) {
    // Only iter-args of this loop's body can be traced back.
    if (arg.getOwner() != forOp.getBody())
      return {nullptr, 0};
    // Ignore induction variable.
    if (arg.getArgNumber() == 0)
      return {nullptr, 0};
    distance++;
    // Follow the iter-arg back to the value yielded for it.
    value =
        forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
  }
  Operation *def = value.getDefiningOp();
  if (!def)
    return {nullptr, 0};
  return {def, distance};
}
|
||||
|
||||
/// Collect every value used in a later stage than the stage defining it,
/// recording its def stage and last use stage; these values must be threaded
/// through the kernel loop as extra iter-args.
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
LoopPipelinerInternal::analyzeCrossStageValues() {
  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
  for (Operation *op : opOrder) {
    unsigned stage = stages[op];

    auto analyzeOperand = [&](OpOperand &operand) {
      auto [def, distance] = getDefiningOpAndDistance(operand.get());
      if (!def)
        return;
      auto defStage = stages.find(def);
      // Same-stage uses (accounting for loop-carried distance) need no
      // extra liverange tracking.
      if (defStage == stages.end() || defStage->second == stage ||
          defStage->second == stage + distance)
        return;
      assert(stage > defStage->second);
      LiverangeInfo &info = crossStageValues[operand.get()];
      info.defStage = defStage->second;
      info.lastUseStage = std::max(info.lastUseStage, stage);
    };

    // Analyze direct operands as well as values captured from above by
    // nested regions.
    for (OpOperand &operand : op->getOpOperands())
      analyzeOperand(operand);
    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
      analyzeOperand(*operand);
    });
  }
  return crossStageValues;
}
|
||||
|
||||
/// Build the new kernel scf.for: seed existing iter-args with the right
/// prologue versions, append one extra iter-arg per live cross-stage version
/// (recorded in `loopArgMap`), and shrink the upper bound when the epilogue
/// is peeled.
scf::ForOp LoopPipelinerInternal::createKernelLoop(
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    RewriterBase &rewriter,
    llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
  // Creates the list of initial values associated to values used across
  // stages. The initial values come from the prologue created above.
  // Keep track of the kernel argument associated to each version of the
  // values passed to the kernel.
  llvm::SmallVector<Value> newLoopArg;
  // For existing loop argument initialize them with the right version from the
  // prologue.
  for (const auto &retVal :
       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
    Operation *def = retVal.value().getDefiningOp();
    assert(def && "Only support loop carried dependencies of distance 1");
    unsigned defStage = stages[def];
    Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
                                     [maxStage - defStage];
    assert(valueVersion);
    newLoopArg.push_back(valueVersion);
  }
  for (auto escape : crossStageValues) {
    LiverangeInfo &info = escape.second;
    Value value = escape.first;
    // One extra iter-arg per stage the value stays live across.
    for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
         stageIdx++) {
      Value valueVersion =
          valueMapping[value][maxStage - info.lastUseStage + stageIdx];
      assert(valueVersion);
      newLoopArg.push_back(valueVersion);
      // Key is (value, how many stages back the use reaches).
      loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
                                           stageIdx)] = newLoopArg.size() - 1;
    }
  }

  // Create the new kernel loop. When we peel the epilogue we need to peel
  // `numStages - 1` iterations. Then we adjust the upper bound to remove those
  // iterations.
  Value newUb = forOp.getUpperBound();
  if (peelEpilogue) {
    Type t = ub.getType();
    Location loc = forOp.getLoc();
    // newUb = ub - maxStage * step
    newUb = rewriter.create<arith::AddIOp>(
        loc, ub,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getIntegerAttr(t, -maxStage))));
  }
  auto newForOp =
      rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                  forOp.getStep(), newLoopArg);
  // When there are no iter args, the loop body terminator will be created.
  // Since we always create it below, remove the terminator if it was created.
  if (!newForOp.getBody()->empty())
    rewriter.eraseOp(newForOp.getBody()->getTerminator());
  return newForOp;
}
|
||||
|
||||
/// Emit the body of the pipelined kernel loop: clone ops in schedule order,
/// remap induction-variable and cross-stage operands to the right versions,
/// predicate stages when the epilogue is not peeled, and build the yield.
LogicalResult LoopPipelinerInternal::createKernel(
    scf::ForOp newForOp,
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
    RewriterBase &rewriter) {
  valueMapping.clear();

  // Create the kernel, we clone instruction based on the order given by
  // user and remap operands coming from a previous stages.
  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
  IRMapping mapping;
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  }
  SmallVector<Value> predicates(maxStage + 1, nullptr);
  if (!peelEpilogue) {
    // Create a predicate for each stage except the last stage.
    Location loc = newForOp.getLoc();
    Type t = ub.getType();
    for (unsigned i = 0; i < maxStage; i++) {
      // c = ub - (maxStage - i) * step
      Value c = rewriter.create<arith::AddIOp>(
          loc, ub,
          rewriter.create<arith::MulIOp>(
              loc, step,
              rewriter.create<arith::ConstantOp>(
                  loc, rewriter.getIntegerAttr(t, -int64_t(maxStage - i)))));

      Value pred = rewriter.create<arith::CmpIOp>(
          newForOp.getLoc(), arith::CmpIPredicate::slt,
          newForOp.getInductionVar(), c);
      predicates[i] = pred;
    }
  }
  for (Operation *op : opOrder) {
    int64_t useStage = stages[op];
    auto *newOp = rewriter.clone(*op, mapping);
    SmallVector<OpOperand *> operands;
    // Collect all the operands for the cloned op and its nested ops.
    op->walk([&operands](Operation *nestedOp) {
      for (OpOperand &operand : nestedOp->getOpOperands()) {
        operands.push_back(&operand);
      }
    });
    for (OpOperand *operand : operands) {
      Operation *nestedNewOp = mapping.lookup(operand->getOwner());
      // Special case for the induction variable uses. We replace it with a
      // version incremented based on the stage where it is used.
      if (operand->get() == forOp.getInductionVar()) {
        rewriter.setInsertionPoint(newOp);

        // offset = (maxStage - stages[op]) * step
        Type t = step.getType();
        Value offset = rewriter.create<arith::MulIOp>(
            forOp.getLoc(), step,
            rewriter.create<arith::ConstantOp>(
                forOp.getLoc(),
                rewriter.getIntegerAttr(t, maxStage - stages[op])));
        Value iv = rewriter.create<arith::AddIOp>(
            forOp.getLoc(), newForOp.getInductionVar(), offset);
        nestedNewOp->setOperand(operand->getOperandNumber(), iv);
        rewriter.setInsertionPointAfter(newOp);
        continue;
      }
      Value source = operand->get();
      // Operands that are iter-args are traced back to the value yielded
      // for them by the previous iteration.
      auto arg = dyn_cast<BlockArgument>(source);
      if (arg && arg.getOwner() == forOp.getBody()) {
        Value ret = forOp.getBody()->getTerminator()->getOperand(
            arg.getArgNumber() - 1);
        Operation *dep = ret.getDefiningOp();
        if (!dep)
          continue;
        auto stageDep = stages.find(dep);
        if (stageDep == stages.end() || stageDep->second == useStage)
          continue;
        // If the value is a loop carried value coming from stage N + 1 remap,
        // it will become a direct use.
        if (stageDep->second == useStage + 1) {
          nestedNewOp->setOperand(operand->getOperandNumber(),
                                  mapping.lookupOrDefault(ret));
          continue;
        }
        source = ret;
      }
      // For operands defined in a previous stage we need to remap it to use
      // the correct region argument. We look for the right version of the
      // Value based on the stage where it is used.
      Operation *def = source.getDefiningOp();
      if (!def)
        continue;
      auto stageDef = stages.find(def);
      if (stageDef == stages.end() || stageDef->second == useStage)
        continue;
      auto remap = loopArgMap.find(
          std::make_pair(operand->get(), useStage - stageDef->second));
      assert(remap != loopArgMap.end());
      nestedNewOp->setOperand(operand->getOperandNumber(),
                              newForOp.getRegionIterArgs()[remap->second]);
    }

    if (predicates[useStage]) {
      newOp = predicateFn(rewriter, newOp, predicates[useStage]);
      if (!newOp)
        return failure();
      // Remap the results to the new predicated one.
      for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
        mapping.map(std::get<0>(values), std::get<1>(values));
    }
    rewriter.setInsertionPointAfter(newOp);
    if (annotateFn)
      annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0);
  }

  // Collect the Values that need to be returned by the forOp. For each
  // value we need to have `LastUseStage - DefStage` number of versions
  // returned.
  // We create a mapping between original values and the associated loop
  // returned values that will be needed by the epilogue.
  llvm::SmallVector<Value> yieldOperands;
  for (OpOperand &yielOperand :
       forOp.getBody()->getTerminator()->getOpOperands()) {
    Value source = mapping.lookupOrDefault(yielOperand.get());
    // When we don't peel the epilogue the yield value is used outside the loop
    // we need to make sure we return the version from numStages - defStage.
    if (!peelEpilogue &&
        !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) {
      auto [def, distance] = getDefiningOpAndDistance(yielOperand.get());
      if (def) {
        auto defStage = stages.find(def);
        if (defStage != stages.end()) {
          Value pred = predicates[defStage->second];
          if (pred) {
            // Select the previous iteration's value when this stage was
            // predicated off.
            source = rewriter.create<arith::SelectOp>(
                pred.getLoc(), pred, source,
                newForOp.getBody()
                    ->getArguments()[yielOperand.getOperandNumber() + 1]);
          }
        }
      }
    }
    yieldOperands.push_back(source);
  }

  for (auto &it : crossStageValues) {
    int64_t version = maxStage - it.second.lastUseStage + 1;
    unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
    // add the original version to yield ops.
    // If there is a live range spanning across more than 2 stages we need to
    // add extra arg.
    for (unsigned i = 1; i < numVersionReturned; i++) {
      setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                      version++);
      yieldOperands.push_back(
          newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
                                             newForOp.getNumInductionVars()]);
    }
    setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                    version++);
    yieldOperands.push_back(mapping.lookupOrDefault(it.first));
  }
  // Map the yield operand to the forOp returned value.
  for (const auto &retVal :
       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
    Operation *def = retVal.value().getDefiningOp();
    assert(def && "Only support loop carried dependencies of distance 1");
    unsigned defStage = stages[def];
    if (defStage > 0) {
      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                      newForOp->getResult(retVal.index()),
                      maxStage - defStage + 1);
    }
  }
  rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
  return success();
}
|
||||
|
||||
/// Emit the peeled epilogue parts after the kernel loop and return, per
/// original loop result, the value that should replace it.
llvm::SmallVector<Value>
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
  llvm::SmallVector<Value> returnValues(forOp->getNumResults());
  // Emit different versions of the induction variable. They will be
  // removed by dead code if not used.
  for (int64_t i = 0; i < maxStage; i++) {
    Location loc = forOp.getLoc();
    Type t = lb.getType();
    Value minusOne =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
    // number of iterations = ((ub - 1) - lb) / step
    Value totlaNumIteration = rewriter.create<arith::DivUIOp>(
        loc,
        rewriter.create<arith::SubIOp>(
            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
        step);
    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
    Value minusI =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
    Value newlastIter = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::AddIOp>(loc, totlaNumIteration, minusI)));
    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
  }
  // Emit `maxStage - 1` epilogue part that includes operations from stages
  // [i; maxStage].
  for (int64_t i = 1; i <= maxStage; i++) {
    for (Operation *op : opOrder) {
      // Part i only contains ops of stage >= i.
      if (stages[op] < i)
        continue;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[maxStage - stages[op] + i];
              newOperand->set(replacement);
            }
          });
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue,
                   i - 1);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        maxStage - stages[op] + i);
        // If the value is a loop carried dependency update the loop argument
        // mapping and keep track of the last version to replace the original
        // forOp uses.
        for (OpOperand &operand :
             forOp.getBody()->getTerminator()->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          unsigned version = maxStage - stages[op] + i + 1;
          // If the version is greater than maxStage it means it maps to the
          // original forOp returned value.
          if (version > maxStage) {
            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
            continue;
          }
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          newOp->getResult(destId), version);
        }
      }
    }
  }
  return returnValues;
}
|
||||
|
||||
/// Record `el` as version `idx` of `key`, lazily creating a version vector
/// large enough for every stage (`maxStage + 1` slots) on first use.
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
  auto &versions = valueMapping[key];
  if (versions.empty())
    versions.resize(maxStage + 1);
  versions[idx] = el;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Drive the full pipelining transformation on `forOp` according to
/// `options`. Returns the new kernel loop, or failure when the loop does not
/// satisfy the pre-conditions or the kernel could not be built.
/// `modifiedIR`, when non-null, is set to whether any IR was changed (the
/// prologue is emitted before kernel creation can still fail).
FailureOr<ForOp>
mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
                              const triton::PipeliningOption &options,
                              bool *modifiedIR) {
  if (modifiedIR)
    *modifiedIR = false;
  LoopPipelinerInternal pipeliner;
  if (!pipeliner.initializeLoopInfo(forOp, options))
    return failure();

  if (modifiedIR)
    *modifiedIR = true;

  // 1. Emit prologue.
  pipeliner.emitPrologue(rewriter);

  // 2. Track values used across stages. When a value cross stages it will
  // need to be passed as loop iteration arguments.
  // We first collect the values that are used in a different stage than where
  // they are defined.
  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
      crossStageValues = pipeliner.analyzeCrossStageValues();

  // Mapping between original loop values used cross stage and the block
  // arguments associated after pipelining. A Value may map to several
  // arguments if its liverange spans across more than 2 stages.
  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
  // 3. Create the new kernel loop and return the block arguments mapping.
  ForOp newForOp =
      pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
  // Create the kernel block, order ops based on user choice and remap
  // operands.
  if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
                                    rewriter)))
    return failure();

  llvm::SmallVector<Value> returnValues =
      newForOp.getResults().take_front(forOp->getNumResults());
  if (options.peelEpilogue) {
    // 4. Emit the epilogue after the new forOp.
    rewriter.setInsertionPointAfter(newForOp);
    returnValues = pipeliner.emitEpilogue(rewriter);
  }
  // 5. Erase the original loop and replace the uses with the epilogue output.
  if (forOp->getNumResults() > 0)
    rewriter.replaceOp(forOp, returnValues);
  else
    rewriter.eraseOp(forOp);

  return newForOp;
}
|
||||
101
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h
Normal file
101
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h
Normal file
@@ -0,0 +1,101 @@
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
|
||||
|
||||
// This is a fork of upstream pipeline transformation. This will be merged back
|
||||
// upstream once we have a stable solution.
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class RewriterBase;
|
||||
class Operation;
|
||||
class Value;
|
||||
|
||||
namespace scf {
|
||||
class ForOp;
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
/// Options to dictate how loops should be pipelined.
|
||||
struct PipeliningOption {
|
||||
/// Lambda returning all the operation in the forOp, with their stage, in the
|
||||
/// order picked for the pipelined loop.
|
||||
using GetScheduleFnType = std::function<void(
|
||||
scf::ForOp, std::vector<std::pair<Operation *, unsigned>> &)>;
|
||||
GetScheduleFnType getScheduleFn = nullptr;
|
||||
enum class PipelinerPart {
|
||||
Prologue,
|
||||
Kernel,
|
||||
Epilogue,
|
||||
};
|
||||
/// Lambda called by the pipeliner to allow the user to annotate the IR while
|
||||
/// it is generated.
|
||||
/// The callback passes the operation created along with the part of the
|
||||
/// pipeline and the iteration index. The iteration index is always 0 for the
|
||||
/// kernel. For the prologue and epilogue, it corresponds to the iteration
|
||||
/// peeled out of the loop in the range [0, maxStage[.
|
||||
using AnnotationlFnType =
|
||||
std::function<void(Operation *, PipelinerPart, unsigned)>;
|
||||
AnnotationlFnType annotateFn = nullptr;
|
||||
|
||||
/// Control whether the epilogue should be peeled out of the loop or
|
||||
/// operations should be predicated to skip the early stages in the last loop
|
||||
/// iterations. If the epilogue is predicated; the user needs to provide a
|
||||
/// lambda to generate the predicated version of operations.
|
||||
bool peelEpilogue = true;
|
||||
|
||||
/// Control whether the transformation checks that the number of iterations is
|
||||
/// greater or equal to the number of stages and skip the transformation if
|
||||
/// this is not the case. If the loop is dynamic and this is set to true the
|
||||
/// pipeliner will have to predicate operations in the the prologue/epilogue.
|
||||
bool supportDynamicLoops = false;
|
||||
|
||||
// Callback to predicate operations when the prologue or epilogue are not
|
||||
// peeled. This takes the original operation, an i1 predicate value and the
|
||||
// pattern rewriter. It is expected to replace the given operation with
|
||||
// the predicated equivalent and return it, or return nullptr if the
|
||||
// predication is impossible. In the latter case, pipelining will fail and
|
||||
// may leave IR in a partially transformed state.
|
||||
using PredicateOpFnType =
|
||||
std::function<Operation *(RewriterBase &, Operation *, Value)>;
|
||||
PredicateOpFnType predicateFn = nullptr;
|
||||
|
||||
// TODO: add option to decide if the prologue should be peeled.
|
||||
};
|
||||
|
||||
/// Generate a pipelined version of the scf.for loop based on the schedule given
|
||||
/// as option. This applies the mechanical transformation of changing the loop
|
||||
/// and generating the prologue/epilogue for the pipelining and doesn't make any
|
||||
/// decision regarding the schedule.
|
||||
/// Based on the options the loop is split into several stages.
|
||||
/// The transformation assumes that the scheduling given by user is valid.
|
||||
/// For example if we break a loop into 3 stages named S0, S1, S2 we would
|
||||
/// generate the following code with the number in parenthesis as the iteration
|
||||
/// index:
|
||||
///
|
||||
/// S0(0) // Prologue
|
||||
/// S0(1) S1(0) // Prologue
|
||||
/// scf.for %I = %C0 to %N - 2 {
|
||||
/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel
|
||||
/// }
|
||||
/// S1(N) S2(N-1) // Epilogue
|
||||
/// S2(N) // Epilogue
|
||||
///
|
||||
/// If `modifiedIR` is provided, it will be set to a value that indicates
|
||||
/// whether pipelining modified the IR before failing, signaling to the caller
|
||||
/// whether they can proceed with different transformations.
|
||||
FailureOr<scf::ForOp> pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp,
|
||||
const PipeliningOption &options,
|
||||
bool *modifiedIR = nullptr);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
|
||||
27
lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h
Normal file
27
lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h
Normal file
@@ -0,0 +1,27 @@
|
||||
#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
|
||||
#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
|
||||
|
||||
#include "PipelineExpander.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
/// This fill out the pipelining options including schedule and annotations for
|
||||
/// wait ops. This also does pre-processing by converting some of the loads into
|
||||
/// async loads so that the IR is ready to be pipelined.
|
||||
bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
|
||||
mlir::triton::PipeliningOption &options);
|
||||
|
||||
/// This does post-processing on the pipelined loop to try to pipeline wgmma
|
||||
/// ops.
|
||||
// TODO: this should be included as part of the pipeline but currently the wgmma
|
||||
// wait modeling is problematic.
|
||||
void asyncLaunchDots(scf::ForOp forOp);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
|
||||
@@ -0,0 +1,88 @@
|
||||
#include "PipelineExpander.h"
|
||||
#include "Schedule.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file will create a schedule that will be handed over to the pipeline
|
||||
// expander.
|
||||
// Software pipeliners are usually separated into two pieces, one that create a
|
||||
// modulo schedule and an expander that rewrites the loop and emits a prologue
|
||||
// and epilogue. This pass first calls a helper that will pre-process the IR
|
||||
// to create async operations and create a modulo schedule. Then we call the
|
||||
// expander to generate the prologue and new loop.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
static void pipelineLoop(scf::ForOp forOp, int numStages) {
|
||||
mlir::triton::PipeliningOption options;
|
||||
// Skip loop with distance > 1 for now.
|
||||
// TODO: relax the constraint in the expander.
|
||||
if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
|
||||
[](Value operand) {
|
||||
Operation *def = operand.getDefiningOp();
|
||||
return !def;
|
||||
}))
|
||||
return;
|
||||
|
||||
bool foundSchedule = false;
|
||||
foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options);
|
||||
|
||||
// TODO: add more pipelines strategy.
|
||||
if (!foundSchedule)
|
||||
return;
|
||||
|
||||
IRRewriter rewriter(forOp->getContext());
|
||||
rewriter.setInsertionPoint(forOp);
|
||||
FailureOr<scf::ForOp> newForOp =
|
||||
mlir::triton::pipelineForLoop(rewriter, forOp, options);
|
||||
|
||||
if (succeeded(newForOp))
|
||||
mlir::triton::asyncLaunchDots(newForOp.value());
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
PipelinePass() = default;
|
||||
PipelinePass(int numStages, int numWarps, int numCTAs,
|
||||
int computeCapability) {
|
||||
this->numStages = numStages;
|
||||
this->numWarps = numWarps;
|
||||
this->numCTAs = numCTAs;
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
if (this->numStages <= 1)
|
||||
return;
|
||||
SmallVector<scf::ForOp> loops;
|
||||
getOperation()->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });
|
||||
for (scf::ForOp forOp : loops) {
|
||||
pipelineLoop(forOp, numStages);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages,
|
||||
int numWarps,
|
||||
int numCTAs,
|
||||
int computeCapability) {
|
||||
return std::make_unique<PipelinePass>(numStages, numWarps, numCTAs,
|
||||
computeCapability);
|
||||
}
|
||||
@@ -332,9 +332,6 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
|
||||
setEncoding({afterArg, result}, info, changed, user);
|
||||
continue;
|
||||
}
|
||||
// Workaround: don't propagate through truncI
|
||||
if (isa<arith::TruncIOp>(user))
|
||||
continue;
|
||||
if (user->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() ||
|
||||
user->hasTrait<mlir::OpTrait::Elementwise>() ||
|
||||
isa<triton::ReduceOp, triton::ExpandDimsOp,
|
||||
@@ -755,7 +752,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
|
||||
map(oldResult, newResult);
|
||||
return newOp;
|
||||
}
|
||||
assert(0 && "unexpected op in rewrite");
|
||||
llvm::report_fatal_error("unexpected op in rewrite");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -772,34 +769,6 @@ static bool canBeRemat(Operation *op) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
|
||||
// updated and needs to be updated separatly for the loop to be correct.
|
||||
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter,
|
||||
scf::ForOp loop,
|
||||
ValueRange newIterOperands) {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(loop);
|
||||
|
||||
// Create a new loop before the existing one, with the extra operands.
|
||||
rewriter.setInsertionPoint(loop);
|
||||
auto operands = llvm::to_vector<4>(loop.getInitArgs());
|
||||
operands.append(newIterOperands.begin(), newIterOperands.end());
|
||||
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
|
||||
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
|
||||
operands);
|
||||
newLoop.getBody()->erase();
|
||||
|
||||
newLoop.getRegion().getBlocks().splice(
|
||||
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
|
||||
for (Value operand : newIterOperands)
|
||||
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
|
||||
|
||||
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
|
||||
loop.getNumResults())))
|
||||
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
|
||||
return newLoop;
|
||||
}
|
||||
|
||||
static void rewriteSlice(SetVector<Value> &slice,
|
||||
DenseMap<Value, Attribute> &layout,
|
||||
ConvertLayoutOp convertOp, IRMapping &mapping) {
|
||||
|
||||
@@ -98,8 +98,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
// Some ops from SCF are illegal
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
|
||||
scf::ReduceReturnOp>();
|
||||
// We have custom versions of some arith operators
|
||||
addIllegalOp<arith::CmpIOp, arith::CmpFOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
|
||||
triton::TritonDialect, cf::ControlFlowDialect,
|
||||
|
||||
@@ -232,8 +232,10 @@ std::string GraphLayoutMarker::getColor(const Type &type) const {
|
||||
return "orange";
|
||||
else if (layout.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return "orangered";
|
||||
else
|
||||
assert(0 && "Unrecognized layout");
|
||||
else {
|
||||
llvm::report_fatal_error("Unrecognized layout");
|
||||
return "unknown";
|
||||
}
|
||||
} else {
|
||||
return "white";
|
||||
}
|
||||
@@ -342,11 +344,39 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (auto view = dyn_cast<triton::ViewOp>(op)) {
|
||||
auto viewDstType = view.getType().cast<RankedTensorType>();
|
||||
RankedTensorType newDstType = RankedTensorType::get(
|
||||
viewDstType.getShape(), viewDstType.getElementType(), targetEncoding);
|
||||
return !triton::gpu::isExpensiveView(view.getOperand().getType(),
|
||||
newDstType);
|
||||
}
|
||||
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
|
||||
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
|
||||
triton::MakeRangeOp, triton::SplatOp>(op);
|
||||
}
|
||||
|
||||
//
|
||||
scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
|
||||
ValueRange newIterOperands) {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(loop);
|
||||
|
||||
// Create a new loop before the existing one, with the extra operands.
|
||||
auto operands = llvm::to_vector<4>(loop.getInitArgs());
|
||||
operands.append(newIterOperands.begin(), newIterOperands.end());
|
||||
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
|
||||
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
|
||||
operands);
|
||||
newLoop.getBody()->erase();
|
||||
newLoop.getRegion().getBlocks().splice(
|
||||
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
|
||||
for (Value operand : newIterOperands)
|
||||
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
|
||||
|
||||
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
|
||||
loop.getNumResults())))
|
||||
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
|
||||
return newLoop;
|
||||
}
|
||||
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping) {
|
||||
|
||||
@@ -79,6 +79,17 @@ void CreateMutexOp::build(::mlir::OpBuilder &builder,
|
||||
build(builder, state, MutexType::get(builder.getContext()));
|
||||
}
|
||||
|
||||
///--- DotWaitOp ---
|
||||
LogicalResult DotWaitOp::inferReturnTypes(
|
||||
::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
for (Value operand : operands)
|
||||
inferredReturnTypes.push_back(operand.getType());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace nvidia_gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -68,7 +68,8 @@ Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef<int64_t> shape,
|
||||
replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout));
|
||||
} else {
|
||||
// Other layouts are generated by passes after PlanCTAPass
|
||||
assert(0 && "replaceCTALayout not implemented");
|
||||
llvm::report_fatal_error("replaceCTALayout not implemented");
|
||||
return layout;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,7 +394,8 @@ bool CTAPlanner::propagateBackward(CastOp cast) {
|
||||
Value output = cast.getResult(0);
|
||||
unsigned numUsers = getNumUsers(input);
|
||||
if (numUsers == 0) {
|
||||
assert(0 && "Unreachable branch");
|
||||
llvm::report_fatal_error("Unreachable branch");
|
||||
return false;
|
||||
} else if (numUsers == 1) {
|
||||
Type outTy = output.getType();
|
||||
if (auto ptrTy = outTy.dyn_cast<triton::PointerType>())
|
||||
@@ -649,7 +651,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
|
||||
return true;
|
||||
if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
|
||||
return externElementwiseOp.getPure();
|
||||
if (llvm::isa<ttg::CmpIOp, ttg::CmpFOp, ttg::SelectOp>(op))
|
||||
if (llvm::isa<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
@@ -711,7 +713,7 @@ bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims,
|
||||
|
||||
bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims,
|
||||
Attribute newSrcLayout) {
|
||||
assert(0 && "processExpandDimsForward not implemented yet");
|
||||
llvm::report_fatal_error("processExpandDimsForward not implemented yet");
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -827,7 +829,7 @@ int findResultIndex(Operation *op, Value result) {
|
||||
for (int i = 0; i < op->getNumResults(); ++i)
|
||||
if (op->getResult(i) == result)
|
||||
return i;
|
||||
assert(0 && "Invalid index of op result");
|
||||
llvm::report_fatal_error("Invalid index of op result");
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -849,7 +851,7 @@ bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) {
|
||||
auto newType = cast.getResult(0).getType();
|
||||
return processForOp(forOp, index, newType);
|
||||
} else {
|
||||
assert(0 && "Unexpected parent op of block argument");
|
||||
llvm::report_fatal_error("Unexpected parent op of block argument");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -869,7 +871,7 @@ bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) {
|
||||
else if (auto forOp = llvm::dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
|
||||
return processForOp(forOp, index, newType);
|
||||
else
|
||||
assert(0 && "Unexpected parent op of YieldOp");
|
||||
llvm::report_fatal_error("Unexpected parent op of YieldOp");
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -936,7 +938,8 @@ bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) {
|
||||
Operation *clonedOp = builder.clone(*defOp);
|
||||
newInput = clonedOp->getResult(0);
|
||||
} else {
|
||||
assert(0 && "Layout conflict for block arg"); // TODO
|
||||
llvm::report_fatal_error("Layout conflict for block arg"); // TODO
|
||||
return false;
|
||||
}
|
||||
}
|
||||
first = false;
|
||||
|
||||
@@ -55,7 +55,7 @@ bool isDivisible(Value v, unsigned divisor) {
|
||||
auto func = dyn_cast<tt::FuncOp>(parentOp);
|
||||
assert(func);
|
||||
if (auto attr = func.getArgAttrOfType<IntegerAttr>(blockArg.getArgNumber(),
|
||||
"tt.max_divisibility"))
|
||||
"tt.divisibility"))
|
||||
return attr.getValue().getZExtValue() % divisor == 0;
|
||||
return false;
|
||||
} else if (v.getParentBlock()->isEntryBlock() && (!v.isa<BlockArgument>())) {
|
||||
@@ -98,13 +98,8 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) {
|
||||
return !(boxDimSwizzle && strideDivisible && enableTMA);
|
||||
}
|
||||
|
||||
// TODO: When encoding exists use triton::gpu::CmpIOp as arith::CmpIOp doesn't
|
||||
// play well with encoding attributes. Move back to arith::CmpIOp when this pass
|
||||
// moves back to triton IR level.
|
||||
Value createCmpOp(OpBuilder &builder, Location loc, RankedTensorType type,
|
||||
arith::CmpIPredicate pred, Value lhs, Value rhs) {
|
||||
if (type.getEncoding())
|
||||
return builder.create<ttg::CmpIOp>(loc, type, pred, lhs, rhs);
|
||||
return builder.create<arith::CmpIOp>(loc, type, pred, lhs, rhs);
|
||||
}
|
||||
|
||||
@@ -358,12 +353,17 @@ class TritonGPURewriteTensorPointerPass
|
||||
: public TritonGPURewriteTensorPointerBase<
|
||||
TritonGPURewriteTensorPointerPass> {
|
||||
private:
|
||||
int computeCapability;
|
||||
// int computeCapability;
|
||||
DenseMap<Value, RewritedInfo> rewritedInfo;
|
||||
|
||||
public:
|
||||
explicit TritonGPURewriteTensorPointerPass(int computeCapability)
|
||||
: computeCapability(computeCapability) {}
|
||||
// explicit TritonGPURewriteTensorPointerPass(int computeCapability)
|
||||
// : computeCapability(computeCapability) {}
|
||||
|
||||
TritonGPURewriteTensorPointerPass() = default;
|
||||
TritonGPURewriteTensorPointerPass(int computeCapability) {
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
|
||||
static bool needRewrite(Operation *op, const DenseSet<Value> &valueToRemove) {
|
||||
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
|
||||
@@ -763,17 +763,16 @@ public:
|
||||
ModuleOp mod = getOperation();
|
||||
|
||||
DenseSet<Value> valueToRemove;
|
||||
mod.walk([&valueToRemove,
|
||||
computeCapability = this->computeCapability](Operation *op) {
|
||||
mod.walk([&valueToRemove, this](Operation *op) {
|
||||
if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(op->getResult(0));
|
||||
}
|
||||
if (llvm::isa<tt::AdvanceOp>(op)) {
|
||||
auto src = op->getOperand(0);
|
||||
if (tt::isTensorPointerType(src.getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(src);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability)) {
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability)) {
|
||||
valueToRemove.insert(op->getResult(0));
|
||||
}
|
||||
}
|
||||
@@ -782,7 +781,7 @@ public:
|
||||
auto src = op->getOperand(0);
|
||||
if (tt::isTensorPointerType(src.getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(src);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(src);
|
||||
}
|
||||
}
|
||||
@@ -791,7 +790,7 @@ public:
|
||||
for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) {
|
||||
if (tt::isTensorPointerType(iterOperands[i].getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(iterOperands[i]);
|
||||
}
|
||||
}
|
||||
@@ -800,7 +799,7 @@ public:
|
||||
for (unsigned i = 0, size = yieldOp.getNumOperands(); i < size; ++i) {
|
||||
if (tt::isTensorPointerType(operands[i].getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(operands[i]);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(operands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -892,7 +892,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(0), dotWait);
|
||||
|
||||
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
|
||||
Reference in New Issue
Block a user