[BACKEND] DotOp enable ld.v4 in MMAv1 (#1020)

The existing distributed-to-distributed layout conversion logic is based
on processing each MMA-block; this requires every MMA-block to share
exactly the same fixed pattern (such as the one described in the [NV PTX
doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-float)).

For MMAv1, however, things are different: the MMA-block has varying
patterns for different shapes and data layouts, as shown below.

<img width="200" alt="image"
src="https://user-images.githubusercontent.com/328693/213354941-731d7856-ad24-4f48-be0e-3cf41532cfa4.png">

This requires all the cell coordinates in DotOp output to be computed.
This commit is contained in:
Yan Chunwei
2023-01-20 01:42:33 +08:00
committed by GitHub
parent 408d1d7e87
commit 88498d104a
17 changed files with 563 additions and 281 deletions

View File

@@ -35,7 +35,9 @@ SmallVector<unsigned> getContigPerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned>
getShapePerCTA(const Attribute &layout,
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
SmallVector<unsigned> getOrder(const Attribute &layout);

View File

@@ -95,9 +95,6 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
// TODO[Superjomn]: Support the case when is_vec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
is_vec4 = true;
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
@@ -135,8 +132,6 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>
];
@@ -403,6 +398,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
// 3-bits to encode the MMA ID to make each unique
int versionMinor = (isARow * (1<<0)) |\
(isBRow * (1<<1)) |\
(isAVec4 * (1<<2)) |\
@@ -424,11 +420,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
// Number of bits in versionMinor to hold the ID of the MMA encoding instance.
// Here 5 bits can hold 32 IDs in a single module.
static constexpr int numBitsToHoldMmaV1ID{5};
// Here is a temporary flag that indicates whether we need to update the warpsPerCTA for MMAv1, since the current backend cannot support the updated wpt.
// The mmav1's wpt-related logic is separated into multiple files, so a global flag is added here for universal coordination.
// TODO[Superjomn]: Remove this flag once the MMAv1 backend is ready.
static constexpr bool _mmaV1UpdateWpt{false};
}];
}

View File

@@ -75,8 +75,10 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
auto srcShape = srcTy.getShape();
auto dstShape = dstTy.getShape();
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcShape);
auto dstShapePerCTA = getShapePerCTA(dstLayout, dstShape);
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);

View File

@@ -120,27 +120,7 @@ private:
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
} else if (mmaLayout.isVolta()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
Value laneIdDiv16 = udiv(laneId, _16);
Value laneIdRem16 = urem(laneId, _16);
Value laneIdRem2 = urem(laneId, _2);
Value laneIdRem16Div8 = udiv(laneIdRem16, _8);
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
Value colWarpOffset = mul(multiDimWarpId[1], _16);
mmaRowIdx[0] =
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
laneIdRem2);
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _8);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
// Volta doesn't follow the pattern here."
} else {
llvm_unreachable("Unexpected MMALayout version");
}
@@ -155,26 +135,12 @@ private:
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.isVolta()) {
// the order of elements in a thread:
// c0, c1, ... c4, c5
// c2, c3, ... c6, c7
if (elemId < 2) {
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 2 && elemId < 4) {
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 4 && elemId < 6) {
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
} else if (elemId >= 6) {
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
}
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
auto coords = DotOpMmaV1ConversionHelper::getMNCoords(
threadId, rewriter, mmaLayout.getWarpsPerCTA(), shape, isARow,
isBRow, isAVec4, isBVec4);
return DotOpMmaV1ConversionHelper::getCoord(elemId, coords);
} else {
llvm_unreachable("Unexpected MMALayout version");
}
@@ -200,7 +166,7 @@ private:
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
auto shapePerCTA = getShapePerCTA(layout, type.getShape());
auto order = getOrder(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
@@ -269,6 +235,109 @@ private:
}
}
// The MMAv1 result layout is quite different from the existing "Replica"
// structure, so a new simple but clear implementation is added for it to
// avoid modifying the logic of the existing one.
//
// Stores (stNotRd=true) the register values `vals` of one replica into the
// shared-memory scratch buffer, or loads (stNotRd=false) them back, using
// per-element coordinates computed for the Volta (MMAv1) layout.
// \param stNotRd       true: write vals to smem; false: read vals from smem
// \param type          tensor type carrying the MMAv1 (or slice-of-MMAv1) layout
// \param multiDimRepId multi-dimensional id of the replica being processed
// \param vec           requested vector width; overridden below (see needTrans)
// \param paddedRepShape padded replica shape used to linearize smem offsets
// \param outOrd        output dimension order; outOrd[0] != 0 means the fastest
//                      axis is not dim 0
// \param vals          values to store, or destination vector for loads
// \param smemBase      base pointer of the shared-memory scratch buffer
// \param shape         tensor shape used to derive shapePerCTA
// \param isDestMma     true when the destination layout itself is MMAv1
void processReplicaForMMAV1(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase,
ArrayRef<int64_t> shape,
bool isDestMma = false) const {
// NOTE(review): accumNumCTAsEachRep and multiDimCTAId below appear unused in
// this function body — candidates for removal once confirmed.
unsigned accumNumCTAsEachRep = 1;
auto layout = type.getEncoding();
MmaEncodingAttr mma = layout.dyn_cast<MmaEncodingAttr>();
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
// A slice layout wraps its parent MMA layout; unwrap it here.
if (sliceLayout)
mma = sliceLayout.getParent().cast<MmaEncodingAttr>();
auto order = getOrder(layout);
auto rank = type.getRank();
int accumSizePerThread = vals.size();
SmallVector<unsigned> numCTAs(rank, 1);
SmallVector<unsigned> numCTAsEachRep(rank, 1);
SmallVector<unsigned> shapePerCTA = getShapePerCTA(layout, shape);
auto elemTy = type.getElementType();
// Only a single CTA replica is handled here (ctaId fixed to 0).
int ctaId = 0;
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
SmallVector<unsigned> multiDimCTAId(rank);
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
std::vector<std::pair<SmallVector<Value>, Value>> coord2valT(
accumSizePerThread);
bool needTrans = outOrd[0] != 0;
if (sliceLayout || isDestMma)
needTrans = false;
// The incoming `vec` is deliberately overridden: vec=2 is only legal after
// the transpose below makes the elements contiguous in smem order.
vec = needTrans ? 2 : 1;
{
// We need to transpose the coordinates and values here to enable vec=2
// when store to smem.
std::vector<std::pair<SmallVector<Value>, Value>> coord2val(
accumSizePerThread);
for (unsigned elemId = 0; elemId < accumSizePerThread; ++elemId) {
// TODO[Superjomn]: Move the coordinate computation out of loop, it is
// duplicate in Volta.
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
multiDimCTAInRepId, shapePerCTA);
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
}
if (needTrans) {
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
mma.decodeVoltaLayoutStates();
DotOpMmaV1ConversionHelper helper(mma);
// Transpose the numM x numN element grid so that consecutive elemIds
// become contiguous along the output order.
int numM = helper.getElemsM(mma.getWarpsPerCTA()[0], shape[0], isARow,
isAVec4);
int numN = accumSizePerThread / numM;
for (int r = 0; r < numM; r++) {
for (int c = 0; c < numN; c++) {
coord2valT[r * numN + c] = std::move(coord2val[c * numM + r]);
}
}
} else {
coord2valT = std::move(coord2val);
}
}
// Now coord2valT holds the transposed and contiguous elements (with vec=2);
// the original order of vals is no longer needed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
auto coord = coord2valT[elemId].first;
Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd);
// Address space 3 is shared memory.
auto elemPtrTy = ptr_ty(elemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(elemTy, vec);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
if (stNotRd) {
// Pack `vec` register values into a vector and store it in one shot.
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
auto currVal = coord2valT[elemId + v].second;
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
}
store(valVec, ptr);
} else {
// Vector-load and scatter the elements back into vals.
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(elemTy, valVec, idx_val(v));
vals[elemId + v] = currVal;
}
}
}
}
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
@@ -293,8 +362,26 @@ private:
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape());
auto dstShapePerCTA = getShapePerCTA(dstLayout, shape);
// For Volta, all the coords for a CTA are calculated.
bool isSrcMmaV1{}, isDstMmaV1{};
if (auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
isSrcMmaV1 = mmaLayout.isVolta();
}
if (auto sliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>()) {
isSrcMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
}
if (auto mmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
isDstMmaV1 = mmaLayout.isVolta();
}
if (auto sliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>()) {
isDstMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
}
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
@@ -326,20 +413,31 @@ private:
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
smemBase);
if (isSrcMmaV1)
processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ true, srcTy,
multiDimRepId, inVec, paddedRepShape, outOrd,
vals, smemBase, shape);
else
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy,
inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape,
outOrd, vals, smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
barrier();
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
outOrd, outVals, smemBase);
if (isDstMmaV1)
processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy,
multiDimRepId, outVec, paddedRepShape, outOrd,
outVals, smemBase, shape, /*isDestMma=*/true);
else
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec,
paddedRepShape, outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
@@ -540,13 +638,13 @@ private:
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc),
loc, rewriter);
res = helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc,
rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
// TODO[Superjomn]: transB is not available here.
bool transB = false;
res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc),
loc, rewriter);
res = helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc,
rewriter);
}
} else {
assert(false && "Unsupported mma layout found");

View File

@@ -47,40 +47,47 @@ struct DotOpMmaV1ConversionHelper {
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
// Help to share some variables across multiple functions for A.
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
// conversion.
struct AParam {
SmallVector<int> rep;
SmallVector<int> spw;
bool isAVec4{};
int vec{}; // This could only used in DotOp, not in
// loadA/loadB/TypeConverter
// TODO[Superjomn]: Support the case when isAVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
// bool isAVec4 = !isARow && shapeTransed[orderTransed[0]] <= 16;
const bool isAVec4{true};
AParam(bool isARow, bool isAVec4) : isAVec4(isAVec4) { build(isARow); }
explicit AParam(bool isARow) {
private:
void build(bool isARow) {
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
rep.assign({repM, 0, repK});
spw.assign({spwM, 0, 1});
vec = 2 * rep[0];
}
};
// Help to share some variables across multiple functions for A.
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
// conversion.
struct BParam {
SmallVector<int> rep;
SmallVector<int> spw;
// TODO[Superjomn]: Support the case when isBVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
// bool isBVec4 = isBRow && shapeTransed[orderTransed[0]] <= 16;
const bool isBVec4{true};
bool isBVec4{};
int vec{}; // This could only used in DotOp, not in
// loadA/loadB/TypeConverter
explicit BParam(bool isBRow) {
BParam(bool isBRow, bool isBVec4) : isBVec4(isBVec4) { build(isBRow); }
private:
void build(bool isBRow) {
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep.assign({0, 2 * packSize1, 1});
spw.assign({0, fpw[1] * 4 * rep[1], 1});
vec = 2 * rep[1];
}
};
@@ -93,13 +100,6 @@ struct DotOpMmaV1ConversionHelper {
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
static Type getMatType(TensorType operand) {
auto *ctx = operand.getContext();
Type fp16Ty = type::f16Ty(ctx);
Type vecTy = vec_ty(fp16Ty, 2);
return struct_ty(SmallVector<Type>{vecTy});
}
static Type getMmaRetType(TensorType operand) {
auto *ctx = operand.getContext();
Type fp32Ty = type::f32Ty(ctx);
@@ -107,78 +107,101 @@ struct DotOpMmaV1ConversionHelper {
return struct_ty(SmallVector<Type>{8, fp32Ty});
}
// Get the number of fp16x2 elements for $a.
// \param shapeTransed: A's shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
AParam param(isARow);
static Type getMatType(TensorType operand) {
auto *ctx = operand.getContext();
Type fp16Ty = type::f16Ty(ctx);
Type vecTy = vec_ty(fp16Ty, 2);
return struct_ty(SmallVector<Type>{vecTy});
}
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
// Get the number of fp16x2 elements for $a.
unsigned getNumM(int M, bool isARow, bool isAVec4) const {
AParam param(isARow, isAVec4);
unsigned numM = param.rep[0] * M / (param.spw[0] * wpt[0]);
return numM;
}
// Get the number of fp16x2 elements for $b.
// \param shapeTransed: B' shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
BParam param(isBRow);
unsigned getNumN(int N, bool isBRow, bool isBVec4) const {
BParam param(isBRow, isBVec4);
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
unsigned numN = param.rep[1] * N / (param.spw[1] * wpt[1]);
return numN;
}
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[1];
int numElemsPerThreadA(ArrayRef<int64_t> shape, bool isARow, bool isAVec4,
int vec) const {
int numM = getNumM(shape[0], isARow, isAVec4);
int NK = shape[1];
// Here we mimic the logic in loadA, the result cannot be calculated
// directly.
llvm::DenseSet<std::pair<int, int>> visited;
auto ld = [&](int m, int k) {
visited.insert({m, k});
if (vec > 4) {
if (isARow)
visited.insert({m, k + 4});
else
visited.insert({m + 1, k});
}
};
// NOTE: We couldn't get the vec from the shared layout.
// int vecA = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
int elemsPerLd = vecGt4 ? 4 : 2;
return (numM / 2) * (NK / 4) * elemsPerLd;
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
if (!visited.count({m, k}))
ld(m, k);
return visited.size() * 2;
}
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[0];
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
int elemsPerLd = vecGt4 ? 4 : 2;
return (numN / 2) * (NK / 4) * elemsPerLd;
int numElemsPerThreadB(ArrayRef<int64_t> shape, bool isBRow, bool isBVec4,
int vec) const {
unsigned numN = getNumN(shape[1], isBRow, isBVec4);
int NK = shape[0];
// Here we mimic the logic in loadA, the result cannot be calculated
// directly.
llvm::DenseSet<std::pair<int, int>> visited;
int elemsPerLd = vec > 4 ? 4 : 2;
auto ld = [&](int n, int k) {
visited.insert({n, k});
if (vec > 4) {
if (isBRow)
visited.insert({n + 1, k});
else
visited.insert({n, k + 4});
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!visited.count({n, k}))
ld(n, k);
}
return visited.size() * 2;
}
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value tensor, bool transA, const SharedMemoryObject &smemObj,
Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
Value loadA(Value tensor, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const {
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
tensorTy.getShape().end());
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
auto shape = tensorTy.getShape();
auto order = sharedLayout.getOrder();
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isARow = order[0] != 0;
AParam param(isARow);
auto [isARow_, _0, isAVec4, _1, _2] = mmaLayout.decodeVoltaLayoutStates();
auto [offsetAM, offsetAK, _0, _1] = computeOffsets(
AParam param(isARow_, isAVec4);
auto [offsetAM, offsetAK, _3, _4] = computeOffsets(
thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);
if (transA) {
std::swap(shape[0], shape[1]);
std::swap(offsetAM, offsetAK);
std::swap(order[0], order[1]);
}
int vecA = sharedLayout.getVec();
auto strides = smemObj.strides;
@@ -254,10 +277,11 @@ struct DotOpMmaV1ConversionHelper {
}
};
unsigned numM = getNumM(shape, order[0] == 1);
unsigned numM = getNumM(shape[0], isARow, isAVec4);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k);
if (!has.count({m, k}))
loadA(m, k);
SmallVector<Value> elems;
elems.reserve(has.size() * 2);
@@ -272,9 +296,8 @@ struct DotOpMmaV1ConversionHelper {
}
// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB(Value tensor, bool transB, const SharedMemoryObject &smemObj,
Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
Value loadB(Value tensor, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const {
// smem
auto strides = smemObj.strides;
@@ -282,14 +305,16 @@ struct DotOpMmaV1ConversionHelper {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
tensorTy.getShape().end());
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
auto shape = tensorTy.getShape();
auto order = sharedLayout.getOrder();
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
BParam param(isBRow);
bool isBRow = order[0] != 0; // is row-major in shared memory layout
// isBRow_ indicates whether B is row-major in DotOperand layout
auto [_0, isBRow_, _1, isBVec4, _2] = mmaLayout.decodeVoltaLayoutStates();
assert(isBRow == isBRow_ && "B need smem isRow");
BParam param(isBRow_, isBVec4);
int vecB = sharedLayout.getVec();
Value strideBN = isBRow ? i32_val(1) : strides[1];
@@ -299,13 +324,8 @@ struct DotOpMmaV1ConversionHelper {
int strideRepN = wpt[1] * fpw[1] * 8;
int strideRepK = 1;
auto [_0, _1, offsetBN, offsetBK] = computeOffsets(
auto [_3, _4, offsetBN, offsetBK] = computeOffsets(
thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
if (transB) {
std::swap(order[0], order[1]);
std::swap(shape[0], shape[1]);
std::swap(offsetBK, offsetBN);
}
// swizzling
int perPhaseB = sharedLayout.getPerPhase();
@@ -371,7 +391,7 @@ struct DotOpMmaV1ConversionHelper {
}
};
unsigned numN = getNumN(shape, order[0] == 1);
unsigned numN = getNumN(shape[1], isBRow, isBVec4);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!hbs.count({n, k}))
@@ -383,6 +403,7 @@ struct DotOpMmaV1ConversionHelper {
elems.push_back(item.second.first);
elems.push_back(item.second.second);
}
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
return res;
@@ -475,6 +496,117 @@ struct DotOpMmaV1ConversionHelper {
return rcds;
}
// Get the number of elements of this thread in M axis. The N axis could be
// further deduced with the accSize / elemsM. \param wpt: the wpt in M axis
// \param M: the shape in M axis
int getElemsM(int wpt, int M, bool isARow, bool isAVec4) {
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
int shapePerCTAM = param.spw[0] * wpt;
return M / shapePerCTAM * param.rep[0];
}
using CoordTy = SmallVector<Value, 2>;
// Get the (m, n) coordinates of the accumulator elements emitted by a thread
// in the Volta (MMAv1) layout, covering the whole tensor (all CTA-tile
// repetitions).
// \param wpt   warps-per-CTA along the (M, N) axes
// \param shape the 2-d tensor shape
// Returns coordinates ordered N-major over (idxM x idxN) — see the product
// loop at the end.
static SmallVector<CoordTy>
getMNCoords(Value thread, ConversionPatternRewriter &rewriter,
ArrayRef<unsigned> wpt, ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4) {
auto *ctx = thread.getContext();
auto loc = UnknownLoc::get(ctx);
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _16 = i32_val(16);
Value _32 = i32_val(32);
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);
// rep/spw (repetitions and shape-per-warp) come from the A/B operand params.
DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
// Delinearize the thread id into lane and (M, N) warp coordinates.
Value lane = urem(thread, _32);
Value warp = udiv(thread, _32);
Value warp0 = urem(warp, i32_val(wpt[0]));
Value warp12 = udiv(warp, i32_val(wpt[0]));
Value warp1 = urem(warp12, i32_val(wpt[1]));
// warp offset
Value offWarpM = mul(warp0, i32_val(spw[0]));
Value offWarpN = mul(warp1, i32_val(spw[1]));
// quad offset
Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
// pair offset
Value offPairM = udiv(urem(lane, _16), _4);
offPairM = urem(offPairM, _fpw0);
offPairM = mul(offPairM, _4);
Value offPairN = udiv(urem(lane, _16), _4);
offPairN = udiv(offPairN, _fpw0);
offPairN = urem(offPairN, _fpw1);
offPairN = mul(offPairN, _4);
// scale by the per-axis repetition count
offPairM = mul(offPairM, i32_val(rep[0] / 2));
offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
offPairN = mul(offPairN, i32_val(rep[1] / 2));
offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
// quad pair offset
Value offLaneM = add(offPairM, offQuadM);
Value offLaneN = add(offPairN, offQuadN);
// a, b offset
Value offsetAM = add(offWarpM, offLaneM);
Value offsetBN = add(offWarpN, offLaneN);
// m indices
Value offsetCM = add(and_(lane, _1), offsetAM);
SmallVector<Value> idxM;
for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
for (unsigned mm = 0; mm < rep[0]; ++mm)
idxM.push_back(add(offsetCM, i32_val(m + mm * 2)));
// n indices
Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
SmallVector<Value> idxN;
for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
for (int nn = 0; nn < rep[1]; ++nn) {
// Each rep along N contributes two consecutive columns (an fp16x2 pair).
idxN.push_back(add(offsetCN, i32_val(n + nn / 2 * 4 +
(nn % 2) * 2 * fpw[1] * rep[1])));
idxN.push_back(
add(offsetCN,
i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1)));
}
}
SmallVector<SmallVector<Value>> axes({idxM, idxN});
// product the axis M and axis N to get coords, ported from
// generator::init_idx method from triton2.0
// TODO[Superjomn]: check the order.
SmallVector<CoordTy> coords;
for (Value x1 : axes[1]) { // N
for (Value x0 : axes[0]) { // M
SmallVector<Value, 2> idx(2);
idx[0] = x0; // M
idx[1] = x1; // N
coords.push_back(std::move(idx));
}
}
return coords; // {M,N} in row-major
}
// Look up the (m, n) coordinate of one accumulator element, as precomputed
// by getMNCoords.
// \param elemId the offset of the element in a thread
// \param coords the full per-thread coordinate list from getMNCoords
static CoordTy getCoord(int elemId, ArrayRef<CoordTy> coords) {
return coords[elemId];
}
private:
static constexpr unsigned instrShape[] = {16, 16, 4};
static constexpr unsigned mmaOrder[] = {0, 1};

View File

@@ -120,11 +120,15 @@ private:
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
assert(isARow == isARow_);
assert(isBRow == isBRow_);
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = helper.getNumM(AShape, isARow);
unsigned numN = helper.getNumN(BShape, isBRow);
unsigned numM = helper.getNumM(AShape[0], isARow, isAVec4_);
unsigned numN = helper.getNumN(BShape[1], isBRow, isBVec4_);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
@@ -156,20 +160,6 @@ private:
return idx;
};
{ // convert the acc's value from accumuator-external order to
// accumulator-internal order.
SmallVector<Value> accInit(acc.size());
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
auto idx = getIdx(m, n);
for (unsigned i = 0; i < 8; ++i)
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
}
acc = accInit;
}
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
@@ -206,7 +196,6 @@ private:
for (auto i = 0; i < 8; i++) {
Value elem = extract_val(f32_ty, res, i32_arr_attr(i));
acc[idx[i]] = elem;
resVals[(m * numN / 2 + n) * 8 + i] = elem;
}
};
@@ -216,6 +205,11 @@ private:
callMMA(m, n, k);
}
// res holds the same layout of acc
for (size_t i = 0; i < acc.size(); ++i) {
resVals[i] = acc[i];
}
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
Value res = getStructFromElements(loc, resVals, rewriter, structTy);

View File

@@ -69,10 +69,42 @@ struct BroadcastOpConversion
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
if (auto srcMma = srcLayout.dyn_cast<MmaEncodingAttr>()) {
// NOTE: This is just an naive fix, but for MMA layout, and 2-d fix should
// be all right.
// TODO[Superjomn]: Replace this with a generic implementation.
if (srcMma.isVolta()) {
assert(srcTy.getElementType().isF16() &&
"Unexpected data type on Volta");
int numElemsPerThread = srcMma.getElemsPerThread(resultTy.getShape());
int srcUniqElems = srcVals.size() / 2;
int dup = numElemsPerThread / srcUniqElems;
SmallVector<Value> retVals;
if (srcShape[0] == 1) { // add-cols
for (int i = 0; i < srcUniqElems; ++i)
for (int k = 0; k < dup; ++k)
retVals.push_back(srcVals[i * 2]);
} else { // add-rows
for (int k = 0; k < dup; ++k)
for (int i = 0; i < srcUniqElems; ++i)
retVals.push_back(srcVals[i]);
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value ret = getStructFromElements(loc, retVals, rewriter, llvmStructTy);
rewriter.replaceOp(op, {ret});
return success();
}
}
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
for (size_t i = 0; i < srcOffsets.size(); i++) {
srcValues[srcOffsets[i]] = srcVals[i];
}
SmallVector<Value> resultVals;
for (size_t i = 0; i < resultOffsets.size(); i++) {
auto offset = resultOffsets[i];
@@ -81,6 +113,7 @@ struct BroadcastOpConversion
offset[j] = 0;
resultVals.push_back(srcValues.lookup(offset));
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
@@ -523,6 +556,29 @@ struct AsyncCommitGroupOpConversion
}
};
namespace mlir {
namespace LLVM {
// Debug helper: emit a device-side printf during LLVM conversion. Thin
// wrapper over PrintfOpConversion::llPrintf.
void vprintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
PrintfOpConversion::llPrintf(msg, args, rewriter);
}
// Debug helper: print `arr` as a single line "<info> t-<tid> e0, e1, ...",
// where each element is formatted with the printf-style specifier
// `elem_repr`, via the device-side printf lowering.
// \param thread    thread-id Value, printed as "t-%d"
// \param arr       values to print
// \param info      free-form prefix for the printed line
// \param elem_repr printf format specifier applied to each element
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
std::string elem_repr, ConversionPatternRewriter &builder) {
std::string fmt = info + " t-%d ";
std::vector<Value> new_arr;
new_arr.reserve(arr.size() + 1);
new_arr.push_back(thread);
// Build "<elem>, <elem>, ..." with no trailing separator; use size_t to
// avoid the signed/unsigned comparison against arr.size().
for (size_t i = 0; i < arr.size(); ++i) {
fmt += elem_repr + ((i + 1 == arr.size()) ? "" : ", ");
new_arr.push_back(arr[i]);
}
vprintf(fmt, new_arr, builder);
}
} // namespace LLVM
} // namespace mlir
void populateTritonGPUToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
@@ -543,4 +599,4 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
}
}

View File

@@ -18,6 +18,20 @@ using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
namespace mlir {
namespace LLVM {
// Helper function for using printf in LLVM conversion.
void vprintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter);
// Prints `arr` as one line prefixed with `info` and the thread id, formatting
// each element with the printf-style specifier `elem_repr`.
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
std::string elem_repr, ConversionPatternRewriter &builder);
} // namespace LLVM
} // namespace mlir
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
@@ -199,6 +213,7 @@ public:
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
Value threadId = cast.getResult(0);
return threadId;
}
@@ -688,8 +703,10 @@ private:
ArrayRef<int64_t> shape) const {
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
for (unsigned i = 0; i < shape[0];
i += getShapePerCTA(mmaLayout, shape)[0]) {
for (unsigned j = 0; j < shape[1];
j += getShapePerCTA(mmaLayout, shape)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
ret.push_back({i + 2, j});
@@ -700,6 +717,7 @@ private:
ret.push_back({i + 2, j + 9});
}
}
return ret;
}
@@ -751,6 +769,9 @@ private:
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const Attribute &layout, ArrayRef<int64_t> shape) const {
if (auto mmaLayout = layout.template dyn_cast<MmaEncodingAttr>()) {
assert(!mmaLayout.isVolta());
}
// step 1, delinearize threadId to get the base index
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);

View File

@@ -113,24 +113,20 @@ public:
}
if (mmaLayout.isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
DotOpMmaV1ConversionHelper helper(mmaLayout);
// TODO[Superjomn]: Both transA and transB are not available here.
bool trans = false;
// TODO[Superjomn]: The order of A and B are not available here.
SmallVector<unsigned> order({1, 0});
if (trans) {
std::swap(shape[0], shape[1]);
std::swap(order[0], order[1]);
}
if (dotOpLayout.getOpIdx() == 0) { // $a
int elems = helper.numElemsPerThreadA(shape, order);
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
int elems =
helper.numElemsPerThreadA(shape, isARow, isAVec4, param.vec);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
int elems = helper.numElemsPerThreadB(shape, order);
DotOpMmaV1ConversionHelper::BParam param(isBRow, isBVec4);
int elems =
helper.numElemsPerThreadB(shape, isBRow, isBVec4, param.vec);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}

View File

@@ -61,6 +61,8 @@ struct SplatOpConversion
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
auto dotOperand =
tensorTy.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto parent = layout.getParent();
Value retVal = constVal;
Type retTy = elemType;
@@ -78,11 +80,21 @@ struct SplatOpConversion
matTy = helper.getMatType();
} else if (mmaLayout.isVolta()) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
numElems = layout.getOpIdx() == 0
? helper.numElemsPerThreadA(shape, {0, 1})
: helper.numElemsPerThreadB(shape, {0, 1});
bool isRow = layout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow, isBRow, isAVec4, isBVec4, _0] =
mmaLayout.decodeVoltaLayoutStates();
if (layout.getOpIdx() == 0) {
DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
numElems =
helper.numElemsPerThreadA(shape, isARow, isAVec4, aParam.vec);
} else {
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
numElems =
helper.numElemsPerThreadB(shape, isBRow, isBVec4, bParam.vec);
}
matTy = helper.getMatType(tensorTy);
}
auto numPackedElems = matTy.cast<LLVM::LLVMStructType>()
.getBody()[0]
.cast<VectorType>()
@@ -92,6 +104,7 @@ struct SplatOpConversion
for (auto i = 0; i < numPackedElems; ++i) {
retVal = insert_element(retTy, retVal, constVal, i32_val(i));
}
} else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
} else {

View File

@@ -112,9 +112,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
if (mmaLayout.isAmpere()) {
return {2, 2};
} else if (mmaLayout.isVolta()) {
// Note: here the definition of sizePerThread is obscure, which doesn't
// mean vecSize=4 can be supported in the last dimension.
return {2, 4};
return {1, 2};
} else {
llvm_unreachable("Unexpected mma version");
}
@@ -173,7 +171,8 @@ SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
return threads;
}
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
SmallVector<unsigned> getShapePerCTA(const Attribute &layout,
ArrayRef<int64_t> tensorShape) {
SmallVector<unsigned> shape;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
@@ -186,15 +185,20 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(getShapePerCTA(parent)[d]);
shape.push_back(getShapePerCTA(parent, tensorShape)[d]);
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere())
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.isVolta())
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.isVolta()) {
assert(!tensorShape.empty() && "Volta needs the tensorShape");
if (tensorShape.size() == 1) // must be SliceEncoding
return {static_cast<unsigned>(tensorShape[0]),
static_cast<unsigned>(tensorShape[0])};
return {static_cast<unsigned>(tensorShape[0]),
static_cast<unsigned>(tensorShape[1])};
}
assert(0 && "Unexpected MMA layout version found");
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
@@ -202,7 +206,7 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
assert(parentMmaLayout.isAmpere() &&
"mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto parentShapePerCTA = getShapePerCTA(parentLayout, tensorShape);
auto opIdx = dotLayout.getOpIdx();
if (opIdx == 0) {
return {parentShapePerCTA[0], 16};

View File

@@ -887,31 +887,8 @@ SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
if (!MmaEncodingAttr::_mmaV1UpdateWpt) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(1 /*version*/);
bool changed = false;
do {
changed = false;
int pre = ret[0];
if (ret[0] * ret[1] < numWarps) {
ret[0] =
std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
changed = pre != ret[0];
}
if (ret[0] * ret[1] < numWarps) {
pre = ret[1];
ret[1] =
std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
changed = pre != ret[1];
}
} while (changed);
return ret;
} else {
// Set a default value and ensure product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}
// Set a default value and ensure product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
@@ -1107,13 +1084,8 @@ public:
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
if (MmaEncodingAttr::_mmaV1UpdateWpt)
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
else
mmaEnc = triton::gpu::MmaEncodingAttr::get(
dotOp->getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTileV1(retShape, numWarps));
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,

View File

@@ -13,6 +13,7 @@ using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SharedEncodingAttr;
using triton::gpu::SliceEncodingAttr;
// This pattern collects the wrong Mma those need to update and create the right
@@ -33,12 +34,13 @@ public:
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
auto *ctx = dotOp->getContext();
auto AT = dotOp.a().getType().cast<RankedTensorType>();
auto BT = dotOp.b().getType().cast<RankedTensorType>();
auto DT = dotOp.d().getType().cast<RankedTensorType>();
auto shapeA = AT.getShape();
auto shapeB = BT.getShape();
if (!DT.getEncoding())
return failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
@@ -53,42 +55,34 @@ public:
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4, isBVec4, mmaId] =
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// The wpt of MMAv1 is also determined by isARow, isBRow and shape, and it
// could only be set here for those states might be updated by previous
// patterns in the Combine Pass.
if (isARow_ == isARow && isBRow_ == isBRow) {
auto tgtWpt =
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
product(mmaLayout.getWarpsPerCTA()));
// Check if the wpt should be updated.
if (tgtWpt == mmaLayout.getWarpsPerCTA() ||
!MmaEncodingAttr::_mmaV1UpdateWpt)
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4,
isBVec4, product(mmaLayout.getWarpsPerCTA()));
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
isBVec4 == isBVec4_) {
if (tgtWpt == mmaLayout.getWarpsPerCTA())
return failure();
}
MmaEncodingAttr newMmaLayout;
{
// Create a temporary MMA layout to obtain the isAVec4 and isBVec4
auto tmpMmaLayout = MmaEncodingAttr::get(
ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
AT.getShape(), BT.getShape(), isARow, isBRow, mmaId);
auto [isARow_, isBRow_, isAVec4_, isBVec4_, _] =
tmpMmaLayout.decodeVoltaLayoutStates();
// Recalculate the wpt, for here we could get the latest information, the
// wpt should be updated.
auto updatedWpt =
getWarpsPerCTA(DT.getShape(), isARow_, isBRow_, isAVec4_, isBVec4_,
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
product(mmaLayout.getWarpsPerCTA()));
auto newWpt = MmaEncodingAttr::_mmaV1UpdateWpt
? updatedWpt
: mmaLayout.getWarpsPerCTA();
newMmaLayout = MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(),
newWpt, AT.getShape(), BT.getShape(),
isARow, isBRow, mmaId);
updatedWpt, AT.getShape(),
BT.getShape(), isARow, isBRow, mmaId);
}
// Collect the wrong MMA Layouts, and mark need to update.
@@ -100,14 +94,14 @@ public:
// Get the wpt for MMAv1 using more information.
// Reference the original logic here
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
SmallVector<unsigned, 2> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4,
bool isBVec4, int numWarps) const {
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4,
int numWarps) const {
// TODO[Superjomn]: Share code with
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
// rep,spw and fpw.
SmallVector<unsigned, 2> wpt({1, 1});
SmallVector<unsigned, 2> wpt_nm1;
SmallVector<unsigned> wpt({1, 1});
SmallVector<unsigned> wpt_nm1;
SmallVector<int, 2> rep(2), spw(2);
std::array<int, 3> fpw{{2, 2, 1}};
@@ -242,7 +236,10 @@ public:
auto srcTy = op->getOperand(0).getType();
auto resTy = op->getResult(0).getType();
if (!needUpdate(srcTy) && needUpdate(resTy)) {
if (needUpdate(resTy)) {
// The op-inputs' types are not necessary to update, for some
// replaceOpWithNewOp will help update them.
op->getResult(0).setType(
getUpdatedType(resTy.dyn_cast<RankedTensorType>()));
return success();

View File

@@ -1108,6 +1108,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
if capability[0] == 7:
if (M, N, K, num_warps) == (128, 256, 32, 8):
pytest.skip("shared memory out of resource")
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
# triton kernel

View File

@@ -900,11 +900,12 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_licm_pass()
pm.add_tritongpu_combine_pass(compute_capability)
if compute_capability // 10 == 7:
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_cse_pass()
pm.add_tritongpu_decompose_conversions_pass()
if compute_capability // 10 == 7:
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
# NOTE this pass should be placed after all the passes those modifies mma layout
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.add_tritongpu_reorder_instructions_pass()

View File

@@ -758,12 +758,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [2, 1]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -775,13 +775,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -868,24 +867,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [2, 2]}>
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
%a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
%a_mat = triton_gpu.convert_layout %a : (tensor<32x64xf16, #shared0>) -> tensor<32x64xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
// TODO[goostavz]: uncomment the following lines after convert_layout[mma<v1> -> blocked] is ready.
// %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
// %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
// %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
// tt.store %36, %38 : tensor<128x256xf32, #blocked>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x64x!tt.ptr<f32>, #blocked>
tt.store %36, %38 : tensor<32x64xf32, #blocked>
return
}
}
@@ -999,7 +998,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%v1 = arith.addi %v0, %blockdimz : i32
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
tt.store %a, %0 : tensor<32xi32, #blocked0>
return
}
}
@@ -1007,7 +1006,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: test_index_cache
// CHECK-LABEL: test_index_cache
func @test_index_cache() {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
@@ -1021,7 +1020,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_base_index_cache
// CHECK-LABEL: test_base_index_cache
func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
@@ -1045,4 +1044,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
}
return
}
}
}

View File

@@ -12,7 +12,7 @@
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
// It creates a new MMA layout to fit with $a and $b's dot_operand, and get the right warpsPerCTA
// The ID of this MMA instance should be 0.
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 4]}>
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
module attributes {"triton_gpu.num-warps" = 16 : i32} {
// CHECK-LABEL: dot_mmav1
func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
@@ -40,8 +40,8 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} {
#mma1 = #triton_gpu.mma<{versionMajor=1, versionMinor=16, warpsPerCTA=[4,4]}>
// Will still get two MMA layouts
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 4]}>
// CHECK: [[new_mma1:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 19, warpsPerCTA = [4, 4]}>
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
// CHECK: [[new_mma1:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 19, warpsPerCTA = [4, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>