[OPTIMIZER][BACKEND] Cleaned up Volta codegen (#1185)

This commit is contained in:
Philippe Tillet
2023-02-14 22:39:35 -08:00
committed by GitHub
parent 8bca84ce3d
commit e3941f9d09
8 changed files with 297 additions and 496 deletions

View File

@@ -159,7 +159,7 @@ struct DotOpMmaV1ConversionHelper {
return M / shapePerCTAM * param.rep[0];
}
using CoordTy = SmallVector<Value, 2>;
using CoordTy = SmallVector<Value>;
// Get the coordinates(m,n) of the elements emit by a thread in accumulator.
static SmallVector<CoordTy>
getMNCoords(Value thread, ConversionPatternRewriter &rewriter,

View File

@@ -69,36 +69,6 @@ struct BroadcastOpConversion
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
if (auto srcMma = srcLayout.dyn_cast<MmaEncodingAttr>()) {
// NOTE: This is just a naive fix; for the MMA layout, a 2-d fix should
// be all right.
// TODO[Superjomn]: Replace this with a generic implementation.
if (srcMma.isVolta()) {
assert(srcTy.getElementType().isF16() &&
"Unexpected data type on Volta");
int numElemsPerThread = srcMma.getElemsPerThread(resultTy.getShape());
int srcUniqElems = srcVals.size() / 2;
int dup = numElemsPerThread / srcUniqElems;
SmallVector<Value> retVals;
if (srcShape[0] == 1) { // add-cols
for (int i = 0; i < srcUniqElems; ++i)
for (int k = 0; k < dup; ++k)
retVals.push_back(srcVals[i * 2]);
} else { // add-rows
for (int k = 0; k < dup; ++k)
for (int i = 0; i < srcUniqElems; ++i)
retVals.push_back(srcVals[i]);
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value ret = getStructFromElements(loc, retVals, rewriter, llvmStructTy);
rewriter.replaceOp(op, {ret});
return success();
}
}
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
for (size_t i = 0; i < srcOffsets.size(); i++) {
@@ -115,6 +85,7 @@ struct BroadcastOpConversion
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
rewriter.replaceOp(op, {resultStruct});

View File

@@ -7,6 +7,7 @@
#include "triton/Analysis/Allocation.h"
//
#include "DotOpHelpers.h"
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
@@ -693,29 +694,103 @@ private:
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
auto wpt = mmaLayout.getWarpsPerCTA();
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
auto [isARow, isBRow, isAVec4, isBVec4, id] =
mmaLayout.decodeVoltaLayoutStates();
Value thread = getThreadId(rewriter, loc);
auto *ctx = thread.getContext();
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _16 = i32_val(16);
Value _32 = i32_val(32);
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);
LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
Value lane = urem(thread, _32);
Value warp = udiv(thread, _32);
Value warp0 = urem(warp, i32_val(wpt[0]));
Value warp12 = udiv(warp, i32_val(wpt[0]));
Value warp1 = urem(warp12, i32_val(wpt[1]));
// warp offset
Value offWarpM = mul(warp0, i32_val(spw[0]));
Value offWarpN = mul(warp1, i32_val(spw[1]));
// quad offset
Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
// pair offset
Value offPairM = udiv(urem(lane, _16), _4);
offPairM = urem(offPairM, _fpw0);
offPairM = mul(offPairM, _4);
Value offPairN = udiv(urem(lane, _16), _4);
offPairN = udiv(offPairN, _fpw0);
offPairN = urem(offPairN, _fpw1);
offPairN = mul(offPairN, _4);
offPairM = mul(offPairM, i32_val(rep[0] / 2));
offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
offPairN = mul(offPairN, i32_val(rep[1] / 2));
offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
// quad pair offset
Value offLaneM = add(offPairM, offQuadM);
Value offLaneN = add(offPairN, offQuadN);
// a, b offset
Value offsetAM = add(offWarpM, offLaneM);
Value offsetBN = add(offWarpN, offLaneN);
// m indices
Value offsetCM = add(and_(lane, _1), offsetAM);
// n indices
Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
return {offsetCM, offsetCN};
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0];
i += getShapePerCTA(mmaLayout, shape)[0]) {
for (unsigned j = 0; j < shape[1];
j += getShapePerCTA(mmaLayout, shape)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
ret.push_back({i + 2, j});
ret.push_back({i + 2, j + 1});
ret.push_back({i, j + 8});
ret.push_back({i, j + 9});
ret.push_back({i + 2, j + 8});
ret.push_back({i + 2, j + 9});
auto [isARow, isBRow, isAVec4, isBVec4, id] =
mmaLayout.decodeVoltaLayoutStates();
LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
auto wpt = mmaLayout.getWarpsPerCTA();
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
SmallVector<unsigned> idxM;
for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
for (unsigned mm = 0; mm < rep[0]; ++mm)
idxM.push_back(m + mm * 2);
SmallVector<unsigned> idxN;
for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
for (int nn = 0; nn < rep[1]; ++nn) {
idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]);
idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1);
}
}
SmallVector<SmallVector<unsigned>> ret;
for (unsigned x1 : idxN) { // N
for (unsigned x0 : idxM) { // M
SmallVector<unsigned> idx(2);
idx[0] = x0; // M
idx[1] = x1; // N
ret.push_back(std::move(idx));
}
}
return ret;
}
@@ -761,17 +836,9 @@ private:
// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
// TODO: [phil] redundant indices computation do not appear to hurt
// performance much, but they could still significantly slow down
// computations.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const Attribute &layout, ArrayRef<int64_t> shape) const {
if (auto mmaLayout = layout.template dyn_cast<MmaEncodingAttr>()) {
assert(!mmaLayout.isVolta());
}
// step 1, delinearize threadId to get the base index
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
// step 2, get offset of each element
@@ -785,7 +852,6 @@ private:
for (unsigned n = 0; n < elemsPerThread; ++n)
for (unsigned k = 0; k < rank; ++k)
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
return multiDimIdx;
}

View File

@@ -29,124 +29,15 @@ struct SplatOpConversion
ConversionPatternRewriter &rewriter,
Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
} else if (auto dotLayout =
tensorTy.getEncoding()
.dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
return convertSplatLikeOpWithDotOperandLayout(
dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
} else if (auto mmaLayout =
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
return convertSplatLikeOpWithMmaLayout(
mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
} else
assert(false && "Unsupported layout found in ConvertSplatLikeOp");
return {};
}
static Value convertSplatLikeOpWithDotOperandLayout(
const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
Type elemType, Value constVal, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
auto dotOperand =
tensorTy.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto parent = layout.getParent();
Value retVal = constVal;
Type retTy = elemType;
int numElems{};
if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
Type matTy;
if (mmaLayout.isAmpere()) {
numElems = layout.getOpIdx() == 0
? MMA16816ConversionHelper::getANumElemsPerThread(
tensorTy, mmaLayout.getWarpsPerCTA()[0])
: MMA16816ConversionHelper::getBNumElemsPerThread(
tensorTy, mmaLayout.getWarpsPerCTA()[1]);
DotOpMmaV2ConversionHelper helper(mmaLayout);
helper.deduceMmaType(tensorTy);
matTy = helper.getMatType();
} else if (mmaLayout.isVolta()) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isRow = layout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow, isBRow, isAVec4, isBVec4, _0] =
mmaLayout.decodeVoltaLayoutStates();
if (layout.getOpIdx() == 0) {
DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
numElems =
helper.numElemsPerThreadA(shape, isARow, isAVec4, aParam.vec);
} else {
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
numElems =
helper.numElemsPerThreadB(shape, isBRow, isBVec4, bParam.vec);
}
matTy = helper.getMatType(tensorTy);
}
auto numPackedElems = matTy.cast<LLVM::LLVMStructType>()
.getBody()[0]
.cast<VectorType>()
.getNumElements();
retTy = vec_ty(elemType, numPackedElems);
retVal = undef(retTy);
for (auto i = 0; i < numPackedElems; ++i) {
retVal = insert_element(retTy, retVal, constVal, i32_val(i));
}
} else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
} else {
assert(false && "Unsupported layout found");
}
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(numElems, retTy));
return getStructFromElements(loc, SmallVector<Value>(numElems, retVal),
rewriter, structTy);
}
static Value convertSplatLikeOpWithMmaLayout(
const MmaEncodingAttr &layout, Type resType, Type elemType,
Value constVal, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
if (layout.isAmpere()) {
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
rewriter, structTy);
}
if (layout.isVolta()) {
DotOpMmaV1ConversionHelper helper(layout);
int repM = helper.getRepM(shape[0]);
int repN = helper.getRepN(shape[1]);
// According to mma layout of v1, each thread process 8 elements.
int elems = 8 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(elems, elemType));
return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
rewriter, structTy);
}
assert(false && "Unsupported mma layout found");
return {};
return getStructFromElements(loc, elems, rewriter, structTy);
}
LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
@@ -254,6 +145,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();

View File

@@ -378,11 +378,19 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
int res = 0;
if (isVolta()) {
unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
// Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
// matrix as result.
res = mmasRow * mmasCol * (16 * 16 / 32);
auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates();
static constexpr std::array<unsigned, 2> fpw{{2, 2}};
unsigned packSize0 = (isARow || isAVec4) ? 1 : 2;
unsigned packSize1 = (isBRow && !isBVec4) ? 2 : 1;
unsigned repM = 2 * packSize0;
unsigned repN = 2 * packSize1;
unsigned spwM = fpw[0] * 4 * repM;
unsigned spwN = fpw[1] * 4 * repN;
unsigned wptM = getWarpsPerCTA()[0];
unsigned wptN = getWarpsPerCTA()[1];
unsigned resM = repM * std::max<int>(1, shape[0] / (spwM * wptM));
unsigned resN = 2 * repN * std::max<int>(1, shape[1] / (spwN * wptN));
res = resM * resN;
} else if (isAmpere()) {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;

View File

@@ -601,7 +601,8 @@ public:
return failure();
}
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!isa<triton::StoreOp>(op)) {
return failure();
@@ -865,8 +866,10 @@ public:
return failure();
for (Operation *op : cvtSlices) {
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultType>())
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!isa<triton::StoreOp>(op))
return failure();
for (Value arg : op->getOperands()) {
Operation *argOp = arg.getDefiningOp();

View File

@@ -6,6 +6,8 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace mlir {
namespace {
@@ -16,297 +18,155 @@ using triton::gpu::MmaEncodingAttr;
using triton::gpu::SharedEncodingAttr;
using triton::gpu::SliceEncodingAttr;
// This pattern collects the wrong MMA encodings that need updating and creates
// the right one for each.
// TODO[Superjomn]: RewritePattern is not needed here; rewrite this as a method
class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
// Holds the mapping from old(wrong) mmaEncodingAttr to the new(correct)
// mmaEncodingAttr.
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
// Get the wpt for MMAv1 using more information.
// Reference the original logic here
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4,
int numWarps) {
// TODO[Superjomn]: Share code with
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
// rep,spw and fpw.
SmallVector<unsigned> wpt({1, 1});
SmallVector<unsigned> wpt_nm1;
public:
CollectMmaToUpdateForVolta(
mlir::MLIRContext *ctx,
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
mmaToUpdate(mmaToUpdate) {}
SmallVector<int, 2> rep(2), spw(2);
std::array<int, 3> fpw{{2, 2, 1}};
int packSize0 = (isARow || isAVec4) ? 1 : 2;
rep[0] = 2 * packSize0;
spw[0] = fpw[0] * 4 * rep[0];
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
auto *ctx = dotOp->getContext();
auto AT = dotOp.a().getType().cast<RankedTensorType>();
auto BT = dotOp.b().getType().cast<RankedTensorType>();
auto DT = dotOp.d().getType().cast<RankedTensorType>();
auto shapeA = AT.getShape();
auto shapeB = BT.getShape();
if (!DT.getEncoding())
return failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
if (!(mmaLayout && mmaLayout.isVolta()))
return failure();
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep[1] = 2 * packSize1;
spw[1] = fpw[1] * 4 * rep[1];
// Already processed.
if (mmaToUpdate.count(mmaLayout))
return failure();
do {
wpt_nm1 = wpt;
if (wpt[0] * wpt[1] < numWarps)
wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shape[0] / spw[0]);
if (wpt[0] * wpt[1] < numWarps)
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shape[1] / spw[1]);
} while (wpt_nm1 != wpt);
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
return wpt;
}
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// Given a (potentially malformed) DotOp, determines the optimal
// MMAEncoding to use on V100
LogicalResult getOptimizedV100MMaLayout(triton::DotOp dotOp,
MmaEncodingAttr &old,
MmaEncodingAttr &ret) {
auto *ctx = dotOp->getContext();
auto AT = dotOp.a().getType().cast<RankedTensorType>();
auto BT = dotOp.b().getType().cast<RankedTensorType>();
auto DT = dotOp.d().getType().cast<RankedTensorType>();
auto shapeA = AT.getShape();
auto shapeB = BT.getShape();
if (!DT.getEncoding())
return mlir::failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
if (!(mmaLayout && mmaLayout.isVolta()))
return mlir::failure();
// We have an MmaEncodingAttr here. Find the correct layout for it.
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// The wpt of MMAv1 is also determined by isARow, isBRow and shape, and it
// could only be set here for those states might be updated by previous
// patterns in the Combine Pass.
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
product(mmaLayout.getWarpsPerCTA()));
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
isBVec4 == isBVec4_) {
if (tgtWpt == mmaLayout.getWarpsPerCTA())
return mlir::failure();
}
// Recalculate the wpt, for here we could get the latest information, the
// wpt should be updated.
auto updatedWpt =
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
product(mmaLayout.getWarpsPerCTA()));
// return results
old = mmaLayout;
ret =
MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(), updatedWpt,
AT.getShape(), BT.getShape(), isARow, isBRow, mmaId);
return mlir::success();
}
// The wpt of MMAv1 is also determined by isARow, isBRow and shape, and it
// could only be set here for those states might be updated by previous
// patterns in the Combine Pass.
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4,
isBVec4, product(mmaLayout.getWarpsPerCTA()));
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
isBVec4 == isBVec4_) {
if (tgtWpt == mmaLayout.getWarpsPerCTA())
return failure();
// Replace result op type
void setOpResultType(Operation *op, ArrayRef<Type> newTypes) {
if (op->getNumResults() != newTypes.size())
llvm_unreachable("number of types different from number of results");
// nothing to do
if (op->getNumResults() == 0)
return;
// replace types
for (unsigned i = 0; i < op->getNumResults(); i++) {
Type newType = newTypes[i];
op->getResult(i).setType(newType);
}
// special case: arith.constant: we need to change the value attr
if (isa<arith::ConstantOp>(op)) {
Type newType = newTypes[0];
auto attr = op->getAttrDictionary()
.get("value")
.dyn_cast<mlir::DenseElementsAttr>();
if (attr) {
auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer(
newType, attr.getRawData(), true);
op->setAttr("value", newAttr);
}
MmaEncodingAttr newMmaLayout;
{
// Recalculate the wpt, for here we could get the latest information, the
// wpt should be updated.
auto updatedWpt =
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
product(mmaLayout.getWarpsPerCTA()));
newMmaLayout = MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(),
updatedWpt, AT.getShape(),
BT.getShape(), isARow, isBRow, mmaId);
}
// Collect the wrong MMA Layouts, and mark need to update.
mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
return failure();
}
}
// Get the wpt for MMAv1 using more information.
// Reference the original logic here
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4,
int numWarps) const {
// TODO[Superjomn]: Share code with
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
// rep,spw and fpw.
SmallVector<unsigned> wpt({1, 1});
SmallVector<unsigned> wpt_nm1;
SmallVector<int, 2> rep(2), spw(2);
std::array<int, 3> fpw{{2, 2, 1}};
int packSize0 = (isARow || isAVec4) ? 1 : 2;
rep[0] = 2 * packSize0;
spw[0] = fpw[0] * 4 * rep[0];
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep[1] = 2 * packSize1;
spw[1] = fpw[1] * 4 * rep[1];
do {
wpt_nm1 = wpt;
if (wpt[0] * wpt[1] < numWarps)
wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shape[0] / spw[0]);
if (wpt[0] * wpt[1] < numWarps)
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shape[1] / spw[1]);
} while (wpt_nm1 != wpt);
return wpt;
// update a stale type given the provided layoutMap
Type updateStaleType(
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &layoutMap,
RankedTensorType type) {
auto encoding = type.getEncoding();
// mma encoding
if (auto mma = encoding.dyn_cast<MmaEncodingAttr>()) {
auto newMma = layoutMap.lookup(mma);
if (!newMma)
return Type();
return RankedTensorType::get(type.getShape(), type.getElementType(),
newMma);
}
};
class UpdateMMAForMMAv1 : public mlir::RewritePattern {
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
public:
UpdateMMAForMMAv1(
MLIRContext *context,
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: RewritePattern(MatchAnyOpTypeTag{}, 1, context),
mmaToUpdate(mmaToUpdate) {}
LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
// Nothing to update
if (mmaToUpdate.empty())
return failure();
if (auto dotOp = llvm::dyn_cast<DotOp>(op))
return rewriteDotOp(op, rewriter);
else if (auto cvtOp = llvm::dyn_cast<ConvertLayoutOp>(op))
return rewriteCvtOp(op, rewriter);
else if (auto expandDimsOp = llvm::dyn_cast<triton::ExpandDimsOp>(op))
return rewriteExpandDimsOp(op, rewriter);
else if (auto constOp = llvm::dyn_cast<arith::ConstantOp>(op))
return rewriteConstantOp(op, rewriter);
else
return rewriteElementwiseOp(op, rewriter);
return failure();
}
LogicalResult rewriteDotOp(Operation *op,
mlir::PatternRewriter &rewriter) const {
auto dotOp = llvm::cast<DotOp>(op);
auto tensorTy = dotOp->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return failure();
auto mma = dotOp.d()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
if (!mma || !mmaToUpdate.count(mma))
return failure();
auto newTensorTy = getUpdatedType(tensorTy);
rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dotOp.a(), dotOp.b(),
dotOp.c(), dotOp.allowTF32());
return success();
}
LogicalResult rewriteCvtOp(Operation *op,
mlir::PatternRewriter &rewriter) const {
auto cvt = llvm::cast<ConvertLayoutOp>(op);
if (!needUpdate(cvt.getResult().getType()))
return failure();
auto tensorTy = cvt.result().getType().dyn_cast<RankedTensorType>();
auto newTensorTy = getUpdatedType(tensorTy);
auto newOp = rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
return success();
}
LogicalResult rewriteExpandDimsOp(Operation *op,
mlir::PatternRewriter &rewriter) const {
auto expandDims = llvm::cast<triton::ExpandDimsOp>(op);
auto srcTy = expandDims.src().getType();
auto resTy = expandDims.getResult().getType();
// the result type need to update
if (!needUpdate(srcTy) && needUpdate(resTy)) {
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, expandDims.src(),
expandDims.axis());
return success();
}
return failure();
}
LogicalResult rewriteConstantOp(Operation *op,
mlir::PatternRewriter &rewriter) const {
auto constant = llvm::cast<arith::ConstantOp>(op);
auto resTy = constant.getResult().getType();
if (!needUpdate(resTy))
return failure();
auto tensorTy = constant.getResult().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
auto dot = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
if (!mma && !dot)
return failure();
auto newTensorTy = getUpdatedType(tensorTy);
if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet =
SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return success();
}
return failure();
}
LogicalResult rewriteElementwiseOp(Operation *op,
mlir::PatternRewriter &rewriter) const {
if (op->getNumOperands() != 1 || op->getNumResults() != 1)
return failure();
auto srcTy = op->getOperand(0).getType();
auto resTy = op->getResult(0).getType();
if (needUpdate(resTy)) {
// The op-inputs' types are not necessary to update, for some
// replaceOpWithNewOp will help update them.
op->getResult(0).setType(
getUpdatedType(resTy.dyn_cast<RankedTensorType>()));
return success();
}
return failure();
}
RankedTensorType getUpdatedType(RankedTensorType type) const {
if (!needUpdate(type))
return type;
auto encoding = type.getEncoding();
if (auto mma = encoding.dyn_cast<MmaEncodingAttr>()) {
auto newMma = mmaToUpdate.lookup(mma);
// slice encoding
else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
if (auto mma = slice.getParent().dyn_cast<MmaEncodingAttr>()) {
auto newMma = layoutMap.lookup(mma);
if (!newMma)
return Type();
auto newSlice =
SliceEncodingAttr::get(slice.getContext(), slice.getDim(), newMma);
return RankedTensorType::get(type.getShape(), type.getElementType(),
newMma);
} else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
if (auto mma = slice.getParent().dyn_cast<MmaEncodingAttr>()) {
auto newMma = mmaToUpdate.lookup(mma);
auto newSlice =
SliceEncodingAttr::get(slice.getContext(), slice.getDim(), newMma);
return RankedTensorType::get(type.getShape(), type.getElementType(),
newSlice);
}
} else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
if (auto mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>()) {
auto newMma = mmaToUpdate.lookup(mma);
auto newDotOp =
DotOperandEncodingAttr::get(dotOp.getContext(), dotOp.getOpIdx(),
newMma, dotOp.getIsMMAv1Row());
return RankedTensorType::get(type.getShape(), type.getElementType(),
newDotOp);
}
newSlice);
}
return type;
}
// Tell if this type contains a wrong MMA encoding and need to update.
bool needUpdate(Type type) const {
auto tensorTy = type.dyn_cast<RankedTensorType>();
if (!tensorTy)
return false;
return needUpdate(tensorTy);
}
// Tell if this type contains a wrong MMA encoding and need to update.
bool needUpdate(RankedTensorType type) const {
auto encoding = type.getEncoding();
if (!encoding)
return false;
MmaEncodingAttr mma;
if ((mma = encoding.dyn_cast<MmaEncodingAttr>())) {
} else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
mma = slice.getParent().dyn_cast<MmaEncodingAttr>();
} else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>();
// dot operand encoding
else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
if (auto mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>()) {
auto newMma = layoutMap.lookup(mma);
if (!newMma)
return Type();
auto newDotOp = DotOperandEncodingAttr::get(
dotOp.getContext(), dotOp.getOpIdx(), newMma, dotOp.getIsMMAv1Row());
return RankedTensorType::get(type.getShape(), type.getElementType(),
newDotOp);
}
return mma && mmaToUpdate.count(mma);
}
};
return Type();
}
} // namespace
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class UpdateMmaForVoltaPass
: public UpdateMmaForVoltaBase<UpdateMmaForVoltaPass> {
public:
@@ -314,34 +174,34 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
{
mlir::RewritePatternSet patterns(context);
patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
GreedyRewriteConfig config;
config.enableRegionSimplification =
false; // The pattern doesn't modify the IR
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
signalPassFailure();
}
if (!mmaToUpdate.empty()) {
mlir::RewritePatternSet patterns(context);
patterns.add<UpdateMMAForMMAv1>(context, mmaToUpdate);
mlir::GreedyRewriteConfig config;
// Make sure the slice and dot_operand layouts' parent mma are updated
// before updating DotOp or it will get a mismatch parent-encoding.
config.useTopDownTraversal = true;
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
signalPassFailure();
if (fixupLoops(m).failed())
signalPassFailure();
}
// Step 1:
// Build a map from old MMA encoding to new MMA encoding.
DenseMap<MmaEncodingAttr, MmaEncodingAttr> layoutMap;
m.walk([&layoutMap](triton::DotOp dotOp) {
MmaEncodingAttr newLayout;
MmaEncodingAttr oldLayout;
if (failed(getOptimizedV100MMaLayout(dotOp, oldLayout, newLayout)))
return;
layoutMap[oldLayout] = newLayout;
});
// Step 2:
// Replace all wrong layouts with the right one
m.walk([&layoutMap](Operation *op) {
if (op->getNumResults() != 1)
return;
auto type = op->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!type)
return;
Type newType = updateStaleType(layoutMap, type);
if (!newType)
return;
setOpResultType(op, {newType});
});
// Step 3:
// We may have messed up some loops in the process.
// Fix them up
if (fixupLoops(m).failed())
signalPassFailure();
}
};

View File

@@ -1240,23 +1240,24 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
def test_dot_without_load(dtype_str):
@triton.jit
def _kernel(out):
a = GENERATE_TEST_HERE
b = GENERATE_TEST_HERE
c = tl.dot(a, b)
out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(out_ptr, c)
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
# @triton.jit
# def _kernel(out):
# a = GENERATE_TEST_HERE
# b = GENERATE_TEST_HERE
# c = tl.dot(a, b)
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(out_ptr, c)
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
out_ref = torch.matmul(a, b)
out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
kernel[(1,)](out)
assert torch.all(out == out_ref)
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# out_ref = torch.matmul(a, b)
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# kernel[(1,)](out)
# assert torch.all(out == out_ref)
# ---------------
# test arange