mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] DotOp enable ld.v4 in MMAv1 (#1020)
The existing convert-distributed-to-distributed-layouts logic is based on processing each MMA-block; this requires each MMA-block to share exactly the same fixed pattern (such as the one described in the [NV PTX doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-float)). For MMAv1, however, things are different: the MMA-block has varying patterns for different shapes and data layouts, as shown below <img width="200" alt="image" src="https://user-images.githubusercontent.com/328693/213354941-731d7856-ad24-4f48-be0e-3cf41532cfa4.png"> This requires all the cell coordinates in the DotOp output to be computed.
This commit is contained in:
@@ -35,7 +35,9 @@ SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
|
||||
SmallVector<unsigned>
|
||||
getShapePerCTA(const Attribute &layout,
|
||||
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
|
||||
|
||||
SmallVector<unsigned> getOrder(const Attribute &layout);
|
||||
|
||||
|
||||
@@ -95,9 +95,6 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
bool is_row = order[0] != 0;
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
|
||||
is_row && (shape[order[0]] <= 16);
|
||||
// TODO[Superjomn]: Support the case when is_vec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
|
||||
is_vec4 = true;
|
||||
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
|
||||
((is_row && !is_vec4) ? 2 : 1);
|
||||
int rep = 2 * pack_size;
|
||||
@@ -135,8 +132,6 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
// ---- not implemented ----
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
@@ -403,6 +398,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
|
||||
// 3-bits to encode the MMA ID to make each unique
|
||||
int versionMinor = (isARow * (1<<0)) |\
|
||||
(isBRow * (1<<1)) |\
|
||||
(isAVec4 * (1<<2)) |\
|
||||
@@ -424,11 +420,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
// Number of bits in versionMinor to hold the ID of the MMA encoding instance.
|
||||
// Here 5 bits can hold 32 IDs in a single module.
|
||||
static constexpr int numBitsToHoldMmaV1ID{5};
|
||||
|
||||
// Here is a temporary flag that indicates whether we need to update the warpsPerCTA for MMAv1, since the current backend cannot support the updated wpt.
|
||||
// The mmav1's wpt-related logic is separated into multiple files, so a global flag is added here for universal coordination.
|
||||
// TODO[Superjomn]: Remove this flag once the MMAv1 backend is ready.
|
||||
static constexpr bool _mmaV1UpdateWpt{false};
|
||||
}];
|
||||
|
||||
}
|
||||
|
||||
@@ -75,8 +75,10 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout);
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto dstShape = dstTy.getShape();
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcShape);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout, dstShape);
|
||||
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
|
||||
@@ -120,27 +120,7 @@ private:
|
||||
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
|
||||
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
|
||||
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
|
||||
Value laneIdDiv16 = udiv(laneId, _16);
|
||||
Value laneIdRem16 = urem(laneId, _16);
|
||||
Value laneIdRem2 = urem(laneId, _2);
|
||||
Value laneIdRem16Div8 = udiv(laneIdRem16, _8);
|
||||
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
|
||||
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
|
||||
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
|
||||
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
|
||||
Value colWarpOffset = mul(multiDimWarpId[1], _16);
|
||||
mmaRowIdx[0] =
|
||||
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
|
||||
laneIdRem2);
|
||||
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
|
||||
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
|
||||
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
|
||||
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
|
||||
mmaColIdx[1] = add(mmaColIdx[0], _1);
|
||||
mmaColIdx[2] = add(mmaColIdx[0], _8);
|
||||
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
|
||||
// Volta doesn't follow the pattern here.
|
||||
} else {
|
||||
llvm_unreachable("Unexpected MMALayout version");
|
||||
}
|
||||
@@ -155,26 +135,12 @@ private:
|
||||
multiDimOffset[1] = add(
|
||||
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
// the order of elements in a thread:
|
||||
// c0, c1, ... c4, c5
|
||||
// c2, c3, ... c6, c7
|
||||
if (elemId < 2) {
|
||||
multiDimOffset[0] = mmaRowIdx[0];
|
||||
multiDimOffset[1] = mmaColIdx[elemId % 2];
|
||||
} else if (elemId >= 2 && elemId < 4) {
|
||||
multiDimOffset[0] = mmaRowIdx[1];
|
||||
multiDimOffset[1] = mmaColIdx[elemId % 2];
|
||||
} else if (elemId >= 4 && elemId < 6) {
|
||||
multiDimOffset[0] = mmaRowIdx[0];
|
||||
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
|
||||
} else if (elemId >= 6) {
|
||||
multiDimOffset[0] = mmaRowIdx[1];
|
||||
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
|
||||
}
|
||||
multiDimOffset[0] = add(
|
||||
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
|
||||
multiDimOffset[1] = add(
|
||||
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
auto coords = DotOpMmaV1ConversionHelper::getMNCoords(
|
||||
threadId, rewriter, mmaLayout.getWarpsPerCTA(), shape, isARow,
|
||||
isBRow, isAVec4, isBVec4);
|
||||
return DotOpMmaV1ConversionHelper::getCoord(elemId, coords);
|
||||
} else {
|
||||
llvm_unreachable("Unexpected MMALayout version");
|
||||
}
|
||||
@@ -200,7 +166,7 @@ private:
|
||||
auto sizePerThread = getSizePerThread(layout);
|
||||
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
||||
SmallVector<unsigned> numCTAs(rank);
|
||||
auto shapePerCTA = getShapePerCTA(layout);
|
||||
auto shapePerCTA = getShapePerCTA(layout, type.getShape());
|
||||
auto order = getOrder(layout);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
|
||||
@@ -269,6 +235,109 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
// The MMAV1's result is quite different from the existing "Replica" structure,
|
||||
// add a new simple but clear implementation for it to avoid modifying the
|
||||
// logic of the existing one.
|
||||
void processReplicaForMMAV1(Location loc, ConversionPatternRewriter &rewriter,
|
||||
bool stNotRd, RankedTensorType type,
|
||||
ArrayRef<unsigned> multiDimRepId, unsigned vec,
|
||||
ArrayRef<unsigned> paddedRepShape,
|
||||
ArrayRef<unsigned> outOrd,
|
||||
SmallVector<Value> &vals, Value smemBase,
|
||||
ArrayRef<int64_t> shape,
|
||||
bool isDestMma = false) const {
|
||||
unsigned accumNumCTAsEachRep = 1;
|
||||
auto layout = type.getEncoding();
|
||||
MmaEncodingAttr mma = layout.dyn_cast<MmaEncodingAttr>();
|
||||
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
|
||||
if (sliceLayout)
|
||||
mma = sliceLayout.getParent().cast<MmaEncodingAttr>();
|
||||
|
||||
auto order = getOrder(layout);
|
||||
auto rank = type.getRank();
|
||||
int accumSizePerThread = vals.size();
|
||||
|
||||
SmallVector<unsigned> numCTAs(rank, 1);
|
||||
SmallVector<unsigned> numCTAsEachRep(rank, 1);
|
||||
SmallVector<unsigned> shapePerCTA = getShapePerCTA(layout, shape);
|
||||
auto elemTy = type.getElementType();
|
||||
|
||||
int ctaId = 0;
|
||||
|
||||
auto multiDimCTAInRepId =
|
||||
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
|
||||
SmallVector<unsigned> multiDimCTAId(rank);
|
||||
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
|
||||
auto d = it.index();
|
||||
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
|
||||
}
|
||||
|
||||
std::vector<std::pair<SmallVector<Value>, Value>> coord2valT(
|
||||
accumSizePerThread);
|
||||
bool needTrans = outOrd[0] != 0;
|
||||
if (sliceLayout || isDestMma)
|
||||
needTrans = false;
|
||||
|
||||
vec = needTrans ? 2 : 1;
|
||||
{
|
||||
// We need to transpose the coordinates and values here to enable vec=2
|
||||
// when store to smem.
|
||||
std::vector<std::pair<SmallVector<Value>, Value>> coord2val(
|
||||
accumSizePerThread);
|
||||
for (unsigned elemId = 0; elemId < accumSizePerThread; ++elemId) {
|
||||
// TODO[Superjomn]: Move the coordinate computation out of loop, it is
|
||||
// duplicate in Volta.
|
||||
SmallVector<Value> multiDimOffset =
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
|
||||
multiDimCTAInRepId, shapePerCTA);
|
||||
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
|
||||
}
|
||||
|
||||
if (needTrans) {
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
|
||||
mma.decodeVoltaLayoutStates();
|
||||
DotOpMmaV1ConversionHelper helper(mma);
|
||||
// do transpose
|
||||
int numM = helper.getElemsM(mma.getWarpsPerCTA()[0], shape[0], isARow,
|
||||
isAVec4);
|
||||
int numN = accumSizePerThread / numM;
|
||||
|
||||
for (int r = 0; r < numM; r++) {
|
||||
for (int c = 0; c < numN; c++) {
|
||||
coord2valT[r * numN + c] = std::move(coord2val[c * numM + r]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
coord2valT = std::move(coord2val);
|
||||
}
|
||||
}
|
||||
|
||||
// Now the coord2valT has the transposed and contiguous elements(with
|
||||
// vec=2), the original vals is not needed.
|
||||
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
|
||||
auto coord = coord2valT[elemId].first;
|
||||
Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd);
|
||||
auto elemPtrTy = ptr_ty(elemTy, 3);
|
||||
Value ptr = gep(elemPtrTy, smemBase, offset);
|
||||
auto vecTy = vec_ty(elemTy, vec);
|
||||
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
|
||||
if (stNotRd) {
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned v = 0; v < vec; ++v) {
|
||||
auto currVal = coord2valT[elemId + v].second;
|
||||
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
|
||||
}
|
||||
store(valVec, ptr);
|
||||
} else {
|
||||
Value valVec = load(ptr);
|
||||
for (unsigned v = 0; v < vec; ++v) {
|
||||
Value currVal = extract_element(elemTy, valVec, idx_val(v));
|
||||
vals[elemId + v] = currVal;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// blocked/mma -> blocked/mma.
|
||||
// Data padding in shared memory to avoid bank conflict.
|
||||
LogicalResult
|
||||
@@ -293,8 +362,26 @@ private:
|
||||
SmallVector<unsigned> outNumCTAsEachRep(rank);
|
||||
SmallVector<unsigned> inNumCTAs(rank);
|
||||
SmallVector<unsigned> outNumCTAs(rank);
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout);
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape());
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout, shape);
|
||||
|
||||
// For Volta, all the coords for a CTA are calculated.
|
||||
bool isSrcMmaV1{}, isDstMmaV1{};
|
||||
if (auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
isSrcMmaV1 = mmaLayout.isVolta();
|
||||
}
|
||||
if (auto sliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>()) {
|
||||
isSrcMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
|
||||
sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
|
||||
}
|
||||
if (auto mmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
isDstMmaV1 = mmaLayout.isVolta();
|
||||
}
|
||||
if (auto sliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>()) {
|
||||
isDstMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
|
||||
sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
|
||||
}
|
||||
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
|
||||
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
|
||||
@@ -326,20 +413,31 @@ private:
|
||||
if (srcLayout.isa<BlockedEncodingAttr>() ||
|
||||
srcLayout.isa<SliceEncodingAttr>() ||
|
||||
srcLayout.isa<MmaEncodingAttr>()) {
|
||||
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
|
||||
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
|
||||
smemBase);
|
||||
if (isSrcMmaV1)
|
||||
processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ true, srcTy,
|
||||
multiDimRepId, inVec, paddedRepShape, outOrd,
|
||||
vals, smemBase, shape);
|
||||
else
|
||||
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy,
|
||||
inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape,
|
||||
outOrd, vals, smemBase);
|
||||
} else {
|
||||
assert(0 && "ConvertLayout with input layout not implemented");
|
||||
return failure();
|
||||
}
|
||||
|
||||
barrier();
|
||||
if (dstLayout.isa<BlockedEncodingAttr>() ||
|
||||
dstLayout.isa<SliceEncodingAttr>() ||
|
||||
dstLayout.isa<MmaEncodingAttr>()) {
|
||||
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
|
||||
outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
|
||||
outOrd, outVals, smemBase);
|
||||
if (isDstMmaV1)
|
||||
processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy,
|
||||
multiDimRepId, outVec, paddedRepShape, outOrd,
|
||||
outVals, smemBase, shape, /*isDestMma=*/true);
|
||||
else
|
||||
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
|
||||
outNumCTAsEachRep, multiDimRepId, outVec,
|
||||
paddedRepShape, outOrd, outVals, smemBase);
|
||||
} else {
|
||||
assert(0 && "ConvertLayout with output layout not implemented");
|
||||
return failure();
|
||||
@@ -540,13 +638,13 @@ private:
|
||||
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
|
||||
// TODO[Superjomn]: transA is not available here.
|
||||
bool transA = false;
|
||||
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc),
|
||||
loc, rewriter);
|
||||
res = helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc,
|
||||
rewriter);
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
|
||||
// TODO[Superjomn]: transB is not available here.
|
||||
bool transB = false;
|
||||
res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc),
|
||||
loc, rewriter);
|
||||
res = helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc,
|
||||
rewriter);
|
||||
}
|
||||
} else {
|
||||
assert(false && "Unsupported mma layout found");
|
||||
|
||||
@@ -47,40 +47,47 @@ struct DotOpMmaV1ConversionHelper {
|
||||
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
|
||||
|
||||
// Help to share some variables across multiple functions for A.
|
||||
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
|
||||
// conversion.
|
||||
struct AParam {
|
||||
SmallVector<int> rep;
|
||||
SmallVector<int> spw;
|
||||
bool isAVec4{};
|
||||
int vec{}; // This could only used in DotOp, not in
|
||||
// loadA/loadB/TypeConverter
|
||||
|
||||
// TODO[Superjomn]: Support the case when isAVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with
|
||||
// different ld vector width.
|
||||
// bool isAVec4 = !isARow && shapeTransed[orderTransed[0]] <= 16;
|
||||
const bool isAVec4{true};
|
||||
AParam(bool isARow, bool isAVec4) : isAVec4(isAVec4) { build(isARow); }
|
||||
|
||||
explicit AParam(bool isARow) {
|
||||
private:
|
||||
void build(bool isARow) {
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
int repM = 2 * packSize0;
|
||||
int repK = 1;
|
||||
int spwM = fpw[0] * 4 * repM;
|
||||
rep.assign({repM, 0, repK});
|
||||
spw.assign({spwM, 0, 1});
|
||||
vec = 2 * rep[0];
|
||||
}
|
||||
};
|
||||
|
||||
// Help to share some variables across multiple functions for A.
|
||||
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
|
||||
// conversion.
|
||||
struct BParam {
|
||||
SmallVector<int> rep;
|
||||
SmallVector<int> spw;
|
||||
// TODO[Superjomn]: Support the case when isBVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with
|
||||
// different ld vector width.
|
||||
// bool isBVec4 = isBRow && shapeTransed[orderTransed[0]] <= 16;
|
||||
const bool isBVec4{true};
|
||||
bool isBVec4{};
|
||||
int vec{}; // This could only used in DotOp, not in
|
||||
// loadA/loadB/TypeConverter
|
||||
|
||||
explicit BParam(bool isBRow) {
|
||||
BParam(bool isBRow, bool isBVec4) : isBVec4(isBVec4) { build(isBRow); }
|
||||
|
||||
private:
|
||||
void build(bool isBRow) {
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
rep.assign({0, 2 * packSize1, 1});
|
||||
spw.assign({0, fpw[1] * 4 * rep[1], 1});
|
||||
vec = 2 * rep[1];
|
||||
}
|
||||
};
|
||||
|
||||
@@ -93,13 +100,6 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
|
||||
|
||||
static Type getMatType(TensorType operand) {
|
||||
auto *ctx = operand.getContext();
|
||||
Type fp16Ty = type::f16Ty(ctx);
|
||||
Type vecTy = vec_ty(fp16Ty, 2);
|
||||
return struct_ty(SmallVector<Type>{vecTy});
|
||||
}
|
||||
|
||||
static Type getMmaRetType(TensorType operand) {
|
||||
auto *ctx = operand.getContext();
|
||||
Type fp32Ty = type::f32Ty(ctx);
|
||||
@@ -107,78 +107,101 @@ struct DotOpMmaV1ConversionHelper {
|
||||
return struct_ty(SmallVector<Type>{8, fp32Ty});
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $a.
|
||||
// \param shapeTransed: A's shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
|
||||
AParam param(isARow);
|
||||
static Type getMatType(TensorType operand) {
|
||||
auto *ctx = operand.getContext();
|
||||
Type fp16Ty = type::f16Ty(ctx);
|
||||
Type vecTy = vec_ty(fp16Ty, 2);
|
||||
return struct_ty(SmallVector<Type>{vecTy});
|
||||
}
|
||||
|
||||
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
|
||||
// Get the number of fp16x2 elements for $a.
|
||||
unsigned getNumM(int M, bool isARow, bool isAVec4) const {
|
||||
AParam param(isARow, isAVec4);
|
||||
|
||||
unsigned numM = param.rep[0] * M / (param.spw[0] * wpt[0]);
|
||||
return numM;
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $b.
|
||||
// \param shapeTransed: B' shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
|
||||
BParam param(isBRow);
|
||||
unsigned getNumN(int N, bool isBRow, bool isBVec4) const {
|
||||
BParam param(isBRow, isBVec4);
|
||||
|
||||
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
|
||||
unsigned numN = param.rep[1] * N / (param.spw[1] * wpt[1]);
|
||||
return numN;
|
||||
}
|
||||
|
||||
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[1];
|
||||
int numElemsPerThreadA(ArrayRef<int64_t> shape, bool isARow, bool isAVec4,
|
||||
int vec) const {
|
||||
int numM = getNumM(shape[0], isARow, isAVec4);
|
||||
int NK = shape[1];
|
||||
// Here we mimic the logic in loadA, the result cannot be calculated
|
||||
// directly.
|
||||
llvm::DenseSet<std::pair<int, int>> visited;
|
||||
auto ld = [&](int m, int k) {
|
||||
visited.insert({m, k});
|
||||
if (vec > 4) {
|
||||
if (isARow)
|
||||
visited.insert({m, k + 4});
|
||||
else
|
||||
visited.insert({m + 1, k});
|
||||
}
|
||||
};
|
||||
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
// int vecA = sharedLayout.getVec();
|
||||
// TODO[Superjomn]: Consider the case when vecA > 4
|
||||
bool vecGt4 = false;
|
||||
int elemsPerLd = vecGt4 ? 4 : 2;
|
||||
return (numM / 2) * (NK / 4) * elemsPerLd;
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
if (!visited.count({m, k}))
|
||||
ld(m, k);
|
||||
|
||||
return visited.size() * 2;
|
||||
}
|
||||
|
||||
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[0];
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
// int vecB = sharedLayout.getVec();
|
||||
// TODO[Superjomn]: Consider the case when vecA > 4
|
||||
bool vecGt4 = false;
|
||||
int elemsPerLd = vecGt4 ? 4 : 2;
|
||||
return (numN / 2) * (NK / 4) * elemsPerLd;
|
||||
int numElemsPerThreadB(ArrayRef<int64_t> shape, bool isBRow, bool isBVec4,
|
||||
int vec) const {
|
||||
unsigned numN = getNumN(shape[1], isBRow, isBVec4);
|
||||
int NK = shape[0];
|
||||
// Here we mimic the logic in loadA, the result cannot be calculated
|
||||
// directly.
|
||||
llvm::DenseSet<std::pair<int, int>> visited;
|
||||
int elemsPerLd = vec > 4 ? 4 : 2;
|
||||
auto ld = [&](int n, int k) {
|
||||
visited.insert({n, k});
|
||||
if (vec > 4) {
|
||||
if (isBRow)
|
||||
visited.insert({n + 1, k});
|
||||
else
|
||||
visited.insert({n, k + 4});
|
||||
}
|
||||
};
|
||||
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
if (!visited.count({n, k}))
|
||||
ld(n, k);
|
||||
}
|
||||
|
||||
return visited.size() * 2;
|
||||
}
|
||||
|
||||
// Loading $a from smem to registers, returns a LLVM::Struct.
|
||||
Value loadA(Value tensor, bool transA, const SharedMemoryObject &smemObj,
|
||||
Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value loadA(Value tensor, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const {
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
|
||||
tensorTy.getShape().end());
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
auto shape = tensorTy.getShape();
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
|
||||
bool isARow = order[0] != 0;
|
||||
AParam param(isARow);
|
||||
auto [isARow_, _0, isAVec4, _1, _2] = mmaLayout.decodeVoltaLayoutStates();
|
||||
|
||||
auto [offsetAM, offsetAK, _0, _1] = computeOffsets(
|
||||
AParam param(isARow_, isAVec4);
|
||||
|
||||
auto [offsetAM, offsetAK, _3, _4] = computeOffsets(
|
||||
thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);
|
||||
|
||||
if (transA) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
std::swap(offsetAM, offsetAK);
|
||||
std::swap(order[0], order[1]);
|
||||
}
|
||||
|
||||
int vecA = sharedLayout.getVec();
|
||||
|
||||
auto strides = smemObj.strides;
|
||||
@@ -254,10 +277,11 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numM = getNumM(shape, order[0] == 1);
|
||||
unsigned numM = getNumM(shape[0], isARow, isAVec4);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
loadA(m, k);
|
||||
if (!has.count({m, k}))
|
||||
loadA(m, k);
|
||||
|
||||
SmallVector<Value> elems;
|
||||
elems.reserve(has.size() * 2);
|
||||
@@ -272,9 +296,8 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Loading $b from smem to registers, returns a LLVM::Struct.
|
||||
Value loadB(Value tensor, bool transB, const SharedMemoryObject &smemObj,
|
||||
Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value loadB(Value tensor, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const {
|
||||
// smem
|
||||
auto strides = smemObj.strides;
|
||||
|
||||
@@ -282,14 +305,16 @@ struct DotOpMmaV1ConversionHelper {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
|
||||
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
|
||||
tensorTy.getShape().end());
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
auto shape = tensorTy.getShape();
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0;
|
||||
BParam param(isBRow);
|
||||
bool isBRow = order[0] != 0; // is row-major in shared memory layout
|
||||
// isBRow_ indicates whether B is row-major in DotOperand layout
|
||||
auto [_0, isBRow_, _1, isBVec4, _2] = mmaLayout.decodeVoltaLayoutStates();
|
||||
assert(isBRow == isBRow_ && "B need smem isRow");
|
||||
|
||||
BParam param(isBRow_, isBVec4);
|
||||
|
||||
int vecB = sharedLayout.getVec();
|
||||
Value strideBN = isBRow ? i32_val(1) : strides[1];
|
||||
@@ -299,13 +324,8 @@ struct DotOpMmaV1ConversionHelper {
|
||||
int strideRepN = wpt[1] * fpw[1] * 8;
|
||||
int strideRepK = 1;
|
||||
|
||||
auto [_0, _1, offsetBN, offsetBK] = computeOffsets(
|
||||
auto [_3, _4, offsetBN, offsetBK] = computeOffsets(
|
||||
thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
|
||||
if (transB) {
|
||||
std::swap(order[0], order[1]);
|
||||
std::swap(shape[0], shape[1]);
|
||||
std::swap(offsetBK, offsetBN);
|
||||
}
|
||||
|
||||
// swizzling
|
||||
int perPhaseB = sharedLayout.getPerPhase();
|
||||
@@ -371,7 +391,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numN = getNumN(shape, order[0] == 1);
|
||||
unsigned numN = getNumN(shape[1], isBRow, isBVec4);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
if (!hbs.count({n, k}))
|
||||
@@ -383,6 +403,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
elems.push_back(item.second.first);
|
||||
elems.push_back(item.second.second);
|
||||
}
|
||||
|
||||
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
|
||||
Value res = getStructFromElements(loc, elems, rewriter, resTy);
|
||||
return res;
|
||||
@@ -475,6 +496,117 @@ struct DotOpMmaV1ConversionHelper {
|
||||
return rcds;
|
||||
}
|
||||
|
||||
// Get the number of elements of this thread in M axis. The N axis could be
|
||||
// further deduced with the accSize / elemsM. \param wpt: the wpt in M axis
|
||||
// \param M: the shape in M axis
|
||||
int getElemsM(int wpt, int M, bool isARow, bool isAVec4) {
|
||||
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
|
||||
int shapePerCTAM = param.spw[0] * wpt;
|
||||
return M / shapePerCTAM * param.rep[0];
|
||||
}
|
||||
|
||||
using CoordTy = SmallVector<Value, 2>;
|
||||
// Get the coordinates(m,n) of the elements emit by a thread in accumulator.
|
||||
static SmallVector<CoordTy>
|
||||
getMNCoords(Value thread, ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<unsigned> wpt, ArrayRef<int64_t> shape, bool isARow,
|
||||
bool isBRow, bool isAVec4, bool isBVec4) {
|
||||
|
||||
auto *ctx = thread.getContext();
|
||||
auto loc = UnknownLoc::get(ctx);
|
||||
Value _1 = i32_val(1);
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
Value _16 = i32_val(16);
|
||||
Value _32 = i32_val(32);
|
||||
Value _fpw0 = i32_val(fpw[0]);
|
||||
Value _fpw1 = i32_val(fpw[1]);
|
||||
|
||||
DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
|
||||
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
|
||||
|
||||
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
|
||||
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
|
||||
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
|
||||
|
||||
Value lane = urem(thread, _32);
|
||||
Value warp = udiv(thread, _32);
|
||||
|
||||
Value warp0 = urem(warp, i32_val(wpt[0]));
|
||||
Value warp12 = udiv(warp, i32_val(wpt[0]));
|
||||
Value warp1 = urem(warp12, i32_val(wpt[1]));
|
||||
|
||||
// warp offset
|
||||
Value offWarpM = mul(warp0, i32_val(spw[0]));
|
||||
Value offWarpN = mul(warp1, i32_val(spw[1]));
|
||||
// quad offset
|
||||
Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
|
||||
Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
|
||||
// pair offset
|
||||
Value offPairM = udiv(urem(lane, _16), _4);
|
||||
offPairM = urem(offPairM, _fpw0);
|
||||
offPairM = mul(offPairM, _4);
|
||||
Value offPairN = udiv(urem(lane, _16), _4);
|
||||
offPairN = udiv(offPairN, _fpw0);
|
||||
offPairN = urem(offPairN, _fpw1);
|
||||
offPairN = mul(offPairN, _4);
|
||||
|
||||
// scale
|
||||
offPairM = mul(offPairM, i32_val(rep[0] / 2));
|
||||
offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
|
||||
offPairN = mul(offPairN, i32_val(rep[1] / 2));
|
||||
offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
|
||||
|
||||
// quad pair offset
|
||||
Value offLaneM = add(offPairM, offQuadM);
|
||||
Value offLaneN = add(offPairN, offQuadN);
|
||||
// a, b offset
|
||||
Value offsetAM = add(offWarpM, offLaneM);
|
||||
Value offsetBN = add(offWarpN, offLaneN);
|
||||
// m indices
|
||||
Value offsetCM = add(and_(lane, _1), offsetAM);
|
||||
SmallVector<Value> idxM;
|
||||
for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
|
||||
for (unsigned mm = 0; mm < rep[0]; ++mm)
|
||||
idxM.push_back(add(offsetCM, i32_val(m + mm * 2)));
|
||||
|
||||
// n indices
|
||||
Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
|
||||
SmallVector<Value> idxN;
|
||||
for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
|
||||
for (int nn = 0; nn < rep[1]; ++nn) {
|
||||
idxN.push_back(add(offsetCN, i32_val(n + nn / 2 * 4 +
|
||||
(nn % 2) * 2 * fpw[1] * rep[1])));
|
||||
idxN.push_back(
|
||||
add(offsetCN,
|
||||
i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1)));
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<Value>> axes({idxM, idxN});
|
||||
|
||||
// product the axis M and axis N to get coords, ported from
|
||||
// generator::init_idx method from triton2.0
|
||||
|
||||
// TODO[Superjomn]: check the order.
|
||||
SmallVector<CoordTy> coords;
|
||||
for (Value x1 : axes[1]) { // N
|
||||
for (Value x0 : axes[0]) { // M
|
||||
SmallVector<Value, 2> idx(2);
|
||||
idx[0] = x0; // M
|
||||
idx[1] = x1; // N
|
||||
coords.push_back(std::move(idx));
|
||||
}
|
||||
}
|
||||
|
||||
return coords; // {M,N} in row-major
|
||||
}
|
||||
|
||||
// \param elemId the offset of the element in a thread
|
||||
static CoordTy getCoord(int elemId, ArrayRef<CoordTy> coords) {
|
||||
return coords[elemId];
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr unsigned instrShape[] = {16, 16, 4};
|
||||
static constexpr unsigned mmaOrder[] = {0, 1};
|
||||
|
||||
@@ -120,11 +120,15 @@ private:
|
||||
|
||||
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
assert(isARow == isARow_);
|
||||
assert(isBRow == isBRow_);
|
||||
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
unsigned numM = helper.getNumM(AShape, isARow);
|
||||
unsigned numN = helper.getNumN(BShape, isBRow);
|
||||
unsigned numM = helper.getNumM(AShape[0], isARow, isAVec4_);
|
||||
unsigned numN = helper.getNumN(BShape[1], isBRow, isBVec4_);
|
||||
unsigned NK = AShape[1];
|
||||
|
||||
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
|
||||
@@ -156,20 +160,6 @@ private:
|
||||
return idx;
|
||||
};
|
||||
|
||||
{ // convert the acc's value from accumuator-external order to
|
||||
// accumulator-internal order.
|
||||
SmallVector<Value> accInit(acc.size());
|
||||
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
auto idx = getIdx(m, n);
|
||||
for (unsigned i = 0; i < 8; ++i)
|
||||
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
|
||||
}
|
||||
|
||||
acc = accInit;
|
||||
}
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has.at({m, k});
|
||||
auto hb = hbs.at({n, k});
|
||||
@@ -206,7 +196,6 @@ private:
|
||||
for (auto i = 0; i < 8; i++) {
|
||||
Value elem = extract_val(f32_ty, res, i32_arr_attr(i));
|
||||
acc[idx[i]] = elem;
|
||||
resVals[(m * numN / 2 + n) * 8 + i] = elem;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -216,6 +205,11 @@ private:
|
||||
callMMA(m, n, k);
|
||||
}
|
||||
|
||||
// res holds the same layout of acc
|
||||
for (size_t i = 0; i < acc.size(); ++i) {
|
||||
resVals[i] = acc[i];
|
||||
}
|
||||
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
|
||||
Value res = getStructFromElements(loc, resVals, rewriter, structTy);
|
||||
|
||||
@@ -69,10 +69,42 @@ struct BroadcastOpConversion
|
||||
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
|
||||
auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
|
||||
SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
|
||||
if (auto srcMma = srcLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
|
||||
// NOTE: This is just an naive fix, but for MMA layout, and 2-d fix should
|
||||
// be all right.
|
||||
// TODO[Superjomn]: Replace this with a generic implementation.
|
||||
if (srcMma.isVolta()) {
|
||||
assert(srcTy.getElementType().isF16() &&
|
||||
"Unexpected data type on Volta");
|
||||
int numElemsPerThread = srcMma.getElemsPerThread(resultTy.getShape());
|
||||
int srcUniqElems = srcVals.size() / 2;
|
||||
int dup = numElemsPerThread / srcUniqElems;
|
||||
SmallVector<Value> retVals;
|
||||
if (srcShape[0] == 1) { // add-cols
|
||||
for (int i = 0; i < srcUniqElems; ++i)
|
||||
for (int k = 0; k < dup; ++k)
|
||||
retVals.push_back(srcVals[i * 2]);
|
||||
|
||||
} else { // add-rows
|
||||
for (int k = 0; k < dup; ++k)
|
||||
for (int i = 0; i < srcUniqElems; ++i)
|
||||
retVals.push_back(srcVals[i]);
|
||||
}
|
||||
|
||||
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
|
||||
Value ret = getStructFromElements(loc, retVals, rewriter, llvmStructTy);
|
||||
|
||||
rewriter.replaceOp(op, {ret});
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
|
||||
for (size_t i = 0; i < srcOffsets.size(); i++) {
|
||||
srcValues[srcOffsets[i]] = srcVals[i];
|
||||
}
|
||||
|
||||
SmallVector<Value> resultVals;
|
||||
for (size_t i = 0; i < resultOffsets.size(); i++) {
|
||||
auto offset = resultOffsets[i];
|
||||
@@ -81,6 +113,7 @@ struct BroadcastOpConversion
|
||||
offset[j] = 0;
|
||||
resultVals.push_back(srcValues.lookup(offset));
|
||||
}
|
||||
|
||||
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
|
||||
@@ -523,6 +556,29 @@ struct AsyncCommitGroupOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
PrintfOpConversion::llPrintf(msg, args, rewriter);
|
||||
}
|
||||
|
||||
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
std::string elem_repr, ConversionPatternRewriter &builder) {
|
||||
std::string fmt = info + " t-%d ";
|
||||
std::vector<Value> new_arr({thread});
|
||||
for (int i = 0; i < arr.size(); ++i) {
|
||||
fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
|
||||
new_arr.push_back(arr[i]);
|
||||
}
|
||||
|
||||
vprintf(fmt, new_arr, builder);
|
||||
}
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
@@ -543,4 +599,4 @@ void populateTritonGPUToLLVMPatterns(
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
|
||||
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,20 @@ using ::mlir::LLVM::SharedMemoryObject;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
// Helper function for using printf in LLVM conversion.
|
||||
void vprintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
std::string elem_repr, ConversionPatternRewriter &builder);
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
||||
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
|
||||
// since it is not exposed on header files in mlir v14
|
||||
@@ -199,6 +213,7 @@ public:
|
||||
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
|
||||
Value threadId = cast.getResult(0);
|
||||
|
||||
return threadId;
|
||||
}
|
||||
|
||||
@@ -688,8 +703,10 @@ private:
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
|
||||
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
|
||||
for (unsigned i = 0; i < shape[0];
|
||||
i += getShapePerCTA(mmaLayout, shape)[0]) {
|
||||
for (unsigned j = 0; j < shape[1];
|
||||
j += getShapePerCTA(mmaLayout, shape)[1]) {
|
||||
ret.push_back({i, j});
|
||||
ret.push_back({i, j + 1});
|
||||
ret.push_back({i + 2, j});
|
||||
@@ -700,6 +717,7 @@ private:
|
||||
ret.push_back({i + 2, j + 9});
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -751,6 +769,9 @@ private:
|
||||
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Attribute &layout, ArrayRef<int64_t> shape) const {
|
||||
if (auto mmaLayout = layout.template dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(!mmaLayout.isVolta());
|
||||
}
|
||||
|
||||
// step 1, delinearize threadId to get the base index
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
|
||||
|
||||
@@ -113,24 +113,20 @@ public:
|
||||
}
|
||||
|
||||
if (mmaLayout.isVolta()) {
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
// TODO[Superjomn]: Both transA and transB are not available here.
|
||||
bool trans = false;
|
||||
// TODO[Superjomn]: The order of A and B are not available here.
|
||||
SmallVector<unsigned> order({1, 0});
|
||||
if (trans) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
std::swap(order[0], order[1]);
|
||||
}
|
||||
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
int elems = helper.numElemsPerThreadA(shape, order);
|
||||
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
|
||||
int elems =
|
||||
helper.numElemsPerThreadA(shape, isARow, isAVec4, param.vec);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 1) { // $b
|
||||
int elems = helper.numElemsPerThreadB(shape, order);
|
||||
DotOpMmaV1ConversionHelper::BParam param(isBRow, isBVec4);
|
||||
int elems =
|
||||
helper.numElemsPerThreadB(shape, isBRow, isBVec4, param.vec);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
|
||||
@@ -61,6 +61,8 @@ struct SplatOpConversion
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto shape = tensorTy.getShape();
|
||||
auto dotOperand =
|
||||
tensorTy.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto parent = layout.getParent();
|
||||
Value retVal = constVal;
|
||||
Type retTy = elemType;
|
||||
@@ -78,11 +80,21 @@ struct SplatOpConversion
|
||||
matTy = helper.getMatType();
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
numElems = layout.getOpIdx() == 0
|
||||
? helper.numElemsPerThreadA(shape, {0, 1})
|
||||
: helper.numElemsPerThreadB(shape, {0, 1});
|
||||
bool isRow = layout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, _0] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
if (layout.getOpIdx() == 0) {
|
||||
DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
|
||||
numElems =
|
||||
helper.numElemsPerThreadA(shape, isARow, isAVec4, aParam.vec);
|
||||
} else {
|
||||
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
|
||||
numElems =
|
||||
helper.numElemsPerThreadB(shape, isBRow, isBVec4, bParam.vec);
|
||||
}
|
||||
matTy = helper.getMatType(tensorTy);
|
||||
}
|
||||
|
||||
auto numPackedElems = matTy.cast<LLVM::LLVMStructType>()
|
||||
.getBody()[0]
|
||||
.cast<VectorType>()
|
||||
@@ -92,6 +104,7 @@ struct SplatOpConversion
|
||||
for (auto i = 0; i < numPackedElems; ++i) {
|
||||
retVal = insert_element(retTy, retVal, constVal, i32_val(i));
|
||||
}
|
||||
|
||||
} else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||
numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
|
||||
} else {
|
||||
|
||||
@@ -112,9 +112,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
return {2, 2};
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
// Note: here the definition of sizePerThread is obscure, which doesn't
|
||||
// mean vecSize=4 can be supported in the last dimension.
|
||||
return {2, 4};
|
||||
return {1, 2};
|
||||
} else {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
@@ -173,7 +171,8 @@ SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
|
||||
return threads;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout,
|
||||
ArrayRef<int64_t> tensorShape) {
|
||||
SmallVector<unsigned> shape;
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
|
||||
@@ -186,15 +185,20 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
|
||||
if (d == dim)
|
||||
continue;
|
||||
shape.push_back(getShapePerCTA(parent)[d]);
|
||||
shape.push_back(getShapePerCTA(parent, tensorShape)[d]);
|
||||
}
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isAmpere())
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
8 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
if (mmaLayout.isVolta())
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
16 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
if (mmaLayout.isVolta()) {
|
||||
assert(!tensorShape.empty() && "Volta needs the tensorShape");
|
||||
if (tensorShape.size() == 1) // must be SliceEncoding
|
||||
return {static_cast<unsigned>(tensorShape[0]),
|
||||
static_cast<unsigned>(tensorShape[0])};
|
||||
return {static_cast<unsigned>(tensorShape[0]),
|
||||
static_cast<unsigned>(tensorShape[1])};
|
||||
}
|
||||
assert(0 && "Unexpected MMA layout version found");
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
@@ -202,7 +206,7 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(parentMmaLayout.isAmpere() &&
|
||||
"mmaLayout version = 1 is not implemented yet");
|
||||
auto parentShapePerCTA = getShapePerCTA(parentLayout);
|
||||
auto parentShapePerCTA = getShapePerCTA(parentLayout, tensorShape);
|
||||
auto opIdx = dotLayout.getOpIdx();
|
||||
if (opIdx == 0) {
|
||||
return {parentShapePerCTA[0], 16};
|
||||
|
||||
@@ -887,31 +887,8 @@ SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
|
||||
|
||||
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
if (!MmaEncodingAttr::_mmaV1UpdateWpt) {
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(1 /*version*/);
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
int pre = ret[0];
|
||||
if (ret[0] * ret[1] < numWarps) {
|
||||
ret[0] =
|
||||
std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
|
||||
changed = pre != ret[0];
|
||||
}
|
||||
if (ret[0] * ret[1] < numWarps) {
|
||||
pre = ret[1];
|
||||
ret[1] =
|
||||
std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
|
||||
changed = pre != ret[1];
|
||||
}
|
||||
} while (changed);
|
||||
return ret;
|
||||
} else {
|
||||
// Set a default value and ensure product of wpt equals numWarps
|
||||
return {static_cast<unsigned>(numWarps), 1};
|
||||
}
|
||||
// Set a default value and ensure product of wpt equals numWarps
|
||||
return {static_cast<unsigned>(numWarps), 1};
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
@@ -1107,13 +1084,8 @@ public:
|
||||
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
|
||||
triton::gpu::MmaEncodingAttr mmaEnc;
|
||||
if (versionMajor == 1) {
|
||||
if (MmaEncodingAttr::_mmaV1UpdateWpt)
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
|
||||
else
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
dotOp->getContext(), versionMajor, 0 /*versionMinor*/,
|
||||
warpsPerTileV1(retShape, numWarps));
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
|
||||
} else if (versionMajor == 2) {
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
|
||||
|
||||
@@ -13,6 +13,7 @@ using triton::DotOp;
|
||||
using triton::gpu::ConvertLayoutOp;
|
||||
using triton::gpu::DotOperandEncodingAttr;
|
||||
using triton::gpu::MmaEncodingAttr;
|
||||
using triton::gpu::SharedEncodingAttr;
|
||||
using triton::gpu::SliceEncodingAttr;
|
||||
|
||||
// This pattern collects the wrong Mma those need to update and create the right
|
||||
@@ -33,12 +34,13 @@ public:
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
auto *ctx = dotOp->getContext();
|
||||
auto AT = dotOp.a().getType().cast<RankedTensorType>();
|
||||
auto BT = dotOp.b().getType().cast<RankedTensorType>();
|
||||
auto DT = dotOp.d().getType().cast<RankedTensorType>();
|
||||
auto shapeA = AT.getShape();
|
||||
auto shapeB = BT.getShape();
|
||||
if (!DT.getEncoding())
|
||||
return failure();
|
||||
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
|
||||
@@ -53,42 +55,34 @@ public:
|
||||
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto [isARow_, isBRow_, isAVec4, isBVec4, mmaId] =
|
||||
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
|
||||
// The wpt of MMAv1 is also determined by isARow, isBRow and shape, and it
|
||||
// could only be set here for those states might be updated by previous
|
||||
// patterns in the Combine Pass.
|
||||
if (isARow_ == isARow && isBRow_ == isBRow) {
|
||||
auto tgtWpt =
|
||||
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
|
||||
product(mmaLayout.getWarpsPerCTA()));
|
||||
// Check if the wpt should be updated.
|
||||
if (tgtWpt == mmaLayout.getWarpsPerCTA() ||
|
||||
!MmaEncodingAttr::_mmaV1UpdateWpt)
|
||||
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4,
|
||||
isBVec4, product(mmaLayout.getWarpsPerCTA()));
|
||||
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
|
||||
isBVec4 == isBVec4_) {
|
||||
if (tgtWpt == mmaLayout.getWarpsPerCTA())
|
||||
return failure();
|
||||
}
|
||||
|
||||
MmaEncodingAttr newMmaLayout;
|
||||
{
|
||||
// Create a temporary MMA layout to obtain the isAVec4 and isBVec4
|
||||
auto tmpMmaLayout = MmaEncodingAttr::get(
|
||||
ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
|
||||
AT.getShape(), BT.getShape(), isARow, isBRow, mmaId);
|
||||
auto [isARow_, isBRow_, isAVec4_, isBVec4_, _] =
|
||||
tmpMmaLayout.decodeVoltaLayoutStates();
|
||||
|
||||
// Recalculate the wpt, for here we could get the latest information, the
|
||||
// wpt should be updated.
|
||||
auto updatedWpt =
|
||||
getWarpsPerCTA(DT.getShape(), isARow_, isBRow_, isAVec4_, isBVec4_,
|
||||
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
|
||||
product(mmaLayout.getWarpsPerCTA()));
|
||||
auto newWpt = MmaEncodingAttr::_mmaV1UpdateWpt
|
||||
? updatedWpt
|
||||
: mmaLayout.getWarpsPerCTA();
|
||||
|
||||
newMmaLayout = MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(),
|
||||
newWpt, AT.getShape(), BT.getShape(),
|
||||
isARow, isBRow, mmaId);
|
||||
updatedWpt, AT.getShape(),
|
||||
BT.getShape(), isARow, isBRow, mmaId);
|
||||
}
|
||||
|
||||
// Collect the wrong MMA Layouts, and mark need to update.
|
||||
@@ -100,14 +94,14 @@ public:
|
||||
// Get the wpt for MMAv1 using more information.
|
||||
// Reference the original logic here
|
||||
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
|
||||
SmallVector<unsigned, 2> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
|
||||
bool isBRow, bool isAVec4,
|
||||
bool isBVec4, int numWarps) const {
|
||||
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
|
||||
bool isBRow, bool isAVec4, bool isBVec4,
|
||||
int numWarps) const {
|
||||
// TODO[Superjomn]: Share code with
|
||||
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
|
||||
// rep,spw and fpw.
|
||||
SmallVector<unsigned, 2> wpt({1, 1});
|
||||
SmallVector<unsigned, 2> wpt_nm1;
|
||||
SmallVector<unsigned> wpt({1, 1});
|
||||
SmallVector<unsigned> wpt_nm1;
|
||||
|
||||
SmallVector<int, 2> rep(2), spw(2);
|
||||
std::array<int, 3> fpw{{2, 2, 1}};
|
||||
@@ -242,7 +236,10 @@ public:
|
||||
|
||||
auto srcTy = op->getOperand(0).getType();
|
||||
auto resTy = op->getResult(0).getType();
|
||||
if (!needUpdate(srcTy) && needUpdate(resTy)) {
|
||||
|
||||
if (needUpdate(resTy)) {
|
||||
// The op-inputs' types are not necessary to update, for some
|
||||
// replaceOpWithNewOp will help update them.
|
||||
op->getResult(0).setType(
|
||||
getUpdatedType(resTy.dyn_cast<RankedTensorType>()));
|
||||
return success();
|
||||
|
||||
@@ -1108,6 +1108,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
if capability[0] == 7:
|
||||
if (M, N, K, num_warps) == (128, 256, 32, 8):
|
||||
pytest.skip("shared memory out of resource")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
# triton kernel
|
||||
|
||||
@@ -900,11 +900,12 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
|
||||
pm.add_tritongpu_combine_pass(compute_capability)
|
||||
pm.add_licm_pass()
|
||||
pm.add_tritongpu_combine_pass(compute_capability)
|
||||
if compute_capability // 10 == 7:
|
||||
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
|
||||
pm.add_tritongpu_update_mma_for_volta_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_tritongpu_decompose_conversions_pass()
|
||||
if compute_capability // 10 == 7:
|
||||
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
|
||||
# NOTE this pass should be placed after all the passes those modifies mma layout
|
||||
pm.add_tritongpu_update_mma_for_volta_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
|
||||
@@ -758,12 +758,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [2, 1]}>
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mmav1_block
|
||||
func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||
func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||
// CHECK: llvm.store
|
||||
@@ -775,13 +775,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: nvvm.barrier0
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
@@ -868,24 +867,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [2, 2]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0]}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
|
||||
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
|
||||
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
%a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
|
||||
%b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
|
||||
%a_mat = triton_gpu.convert_layout %a : (tensor<32x64xf16, #shared0>) -> tensor<32x64xf16, #dot_operand_a>
|
||||
%b_mat = triton_gpu.convert_layout %b : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #dot_operand_b>
|
||||
|
||||
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
|
||||
// TODO[goostavz]: uncomment the following lines after convert_layout[mma<v1> -> blocked] is ready.
|
||||
// %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
|
||||
// %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
|
||||
// %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
|
||||
// tt.store %36, %38 : tensor<128x256xf32, #blocked>
|
||||
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma>
|
||||
%38 = triton_gpu.convert_layout %28 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
|
||||
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
|
||||
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x64x!tt.ptr<f32>, #blocked>
|
||||
tt.store %36, %38 : tensor<32x64xf32, #blocked>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -999,7 +998,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%v1 = arith.addi %v0, %blockdimz : i32
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
tt.store %a, %0 : tensor<32xi32, #blocked0>
|
||||
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1007,7 +1006,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// -----
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: test_index_cache
|
||||
// CHECK-LABEL: test_index_cache
|
||||
func @test_index_cache() {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
@@ -1021,7 +1020,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_base_index_cache
|
||||
// CHECK-LABEL: test_base_index_cache
|
||||
func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
@@ -1045,4 +1044,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
|
||||
// It creates a new MMA layout to fit with $a and $b's dot_operand, and get the right warpsPerCTA
|
||||
// The ID of this MMA instance should be 0.
|
||||
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 4]}>
|
||||
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
|
||||
module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
||||
// CHECK-LABEL: dot_mmav1
|
||||
func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
||||
@@ -40,8 +40,8 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
||||
#mma1 = #triton_gpu.mma<{versionMajor=1, versionMinor=16, warpsPerCTA=[4,4]}>
|
||||
|
||||
// Will still get two MMA layouts
|
||||
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 4]}>
|
||||
// CHECK: [[new_mma1:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 19, warpsPerCTA = [4, 4]}>
|
||||
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
|
||||
// CHECK: [[new_mma1:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 19, warpsPerCTA = [4, 2]}>
|
||||
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
|
||||
|
||||
Reference in New Issue
Block a user