mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)
we currently have a very janky approach to optimizing mixed-precision matmul workloads, where some layout combinations (e.g., NT matmul) were explicitly pattern-matched to take a more optimized codepath. Attempt at unifying all the codepaths to codegen cp.async failed, due to bugs in SharedToDotOperandMMAv2.cpp. This PR fixes said bugs, add some assertions for SharedToDotOperandMMAv2 modes that aren't well supported, and greatly simplify our handling of element-wise operations between load and conversions to DotOperand.
This commit is contained in:
@@ -83,18 +83,17 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
if(!mmaEnc)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
|
||||
// number of rows per phase
|
||||
int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
|
||||
// index of the inner dimension in `order`
|
||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||
|
||||
// ---- begin Volta ----
|
||||
if (mmaEnc.isVolta()) {
|
||||
int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
bool is_row = order[0] != 0;
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
|
||||
is_row && (shape[order[0]] <= 16);
|
||||
@@ -108,10 +107,11 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
// ---- begin Ampere ----
|
||||
if (mmaEnc.isAmpere()) {
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / typeWidthInBit};
|
||||
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getMMAv2kWidth());
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getMMAv2kWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if (typeWidthInBit == 8 && order[0] == inner)
|
||||
if ((32 / typeWidthInBit != dotOpEnc.getMMAv2kWidth()) && order[0] == inner)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
// --- handle A operand ---
|
||||
|
||||
@@ -60,6 +60,7 @@ private:
|
||||
SmallVector<uint32_t> warpsPerCTA;
|
||||
int kOrder;
|
||||
int kWidth;
|
||||
int vecWidth;
|
||||
SmallVector<int64_t> tileShape;
|
||||
SmallVector<int> instrShape;
|
||||
SmallVector<int> matShape;
|
||||
@@ -178,13 +179,13 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
// <----------------------------------------->
|
||||
// vecWidth
|
||||
// <------->
|
||||
// t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\
|
||||
// *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || *t0 ... *t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\
|
||||
// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 |
|
||||
// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height
|
||||
// ... |
|
||||
// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/
|
||||
// --------------------------------------------- || --------------------------------------------
|
||||
// t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3
|
||||
// *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3
|
||||
// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7
|
||||
// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11
|
||||
// ...
|
||||
@@ -206,23 +207,21 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
|
||||
|
||||
SmallVector<Value> offs(numPtrs);
|
||||
|
||||
int vecWidth = kWidth;
|
||||
int threadsPerQuad[2] = {8, 4};
|
||||
int laneWidth = 4;
|
||||
int laneHeight = 8;
|
||||
int quadWidth = laneWidth * vecWidth;
|
||||
int quadWidth = laneWidth * kWidth;
|
||||
int quadHeight = laneHeight;
|
||||
int numQuadI = 2;
|
||||
|
||||
// outer index base
|
||||
Value iBase = udiv(lane, i32_val(laneWidth));
|
||||
|
||||
for (int rep = 0; rep < numPtrs / (2 * vecWidth); ++rep)
|
||||
for (int rep = 0; rep < numPtrs / (2 * kWidth); ++rep)
|
||||
for (int quadId = 0; quadId < 2; ++quadId)
|
||||
for (int elemId = 0; elemId < vecWidth; ++elemId) {
|
||||
int idx = rep * 2 * vecWidth + quadId * vecWidth + elemId;
|
||||
for (int elemId = 0; elemId < kWidth; ++elemId) {
|
||||
// inner index base
|
||||
Value jBase = mul(urem(lane, i32_val(laneWidth)), i32_val(vecWidth));
|
||||
Value jBase = mul(urem(lane, i32_val(laneWidth)), i32_val(kWidth));
|
||||
jBase = add(jBase, i32_val(elemId));
|
||||
// inner index offset
|
||||
Value jOff = i32_val(0);
|
||||
@@ -250,9 +249,17 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
|
||||
// To prevent out-of-bound access when tile is too small.
|
||||
Value i = add(iBase, mul(iOff, i32_val(quadHeight)));
|
||||
Value j = add(jBase, mul(jOff, i32_val(quadWidth)));
|
||||
// wrap around the bounds
|
||||
// i = urem(i, i32_val(cTileShape));
|
||||
// j = urem(j, i32_val(sTileShape));
|
||||
// Compute id of this ptr
|
||||
int idx = rep * 2 * kWidth;
|
||||
if (needTrans) {
|
||||
idx += quadId * vecWidth;
|
||||
idx += elemId % vecWidth;
|
||||
idx += elemId / vecWidth * kWidth;
|
||||
} else {
|
||||
idx += quadId * kWidth;
|
||||
idx += elemId;
|
||||
}
|
||||
|
||||
if (needTrans) {
|
||||
offs[idx] = add(i, mul(j, stridedSmemOffset));
|
||||
} else {
|
||||
@@ -274,7 +281,7 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
|
||||
if (canUseLdmatrix)
|
||||
ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
|
||||
else
|
||||
ptrIdx = matIdx[order[0]] * 4 / elemBytes;
|
||||
ptrIdx = matIdx[order[0]] * (needTrans ? kWidth : vecWidth);
|
||||
|
||||
// The main difference with the original triton code is we removed the
|
||||
// prefetch-related logic here for the upstream optimizer phase should
|
||||
@@ -323,11 +330,8 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
|
||||
return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
|
||||
extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
|
||||
} else {
|
||||
if (needTrans && (4 / elemBytes) != kWidth)
|
||||
llvm_unreachable("unimplemented Shared -> DotOperandMmav2 code path");
|
||||
// base pointers
|
||||
std::array<std::array<Value, 4>, 2> ptrs;
|
||||
int vecWidth = 4 / elemBytes;
|
||||
for (int i = 0; i < vecWidth; i++)
|
||||
ptrs[0][i] = getPtr(ptrIdx + i);
|
||||
for (int i = 0; i < vecWidth; i++)
|
||||
@@ -336,7 +340,8 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
|
||||
int _i0 = matIdx[order[1]] * (stridedLoadMatOffset * stridedMatShape);
|
||||
int _i1 = _i0;
|
||||
if (needTrans)
|
||||
_i1 += stridedLoadMatOffset * stridedMatShape;
|
||||
_i1 += (kWidth != vecWidth) ? vecWidth
|
||||
: stridedLoadMatOffset * stridedMatShape;
|
||||
else
|
||||
_i1 += (kOrder == 1 ? 1 : stridedLoadMatOffset) * stridedMatShape;
|
||||
Value i0 = mul(i32_val(_i0), stridedSmemOffset);
|
||||
@@ -345,9 +350,11 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
|
||||
// load 4 32-bit values from shared memory
|
||||
// (equivalent to ldmatrix.x4)
|
||||
SmallVector<SmallVector<Value>> vptrs(4, SmallVector<Value>(vecWidth));
|
||||
|
||||
for (int i = 0; i < 4; ++i)
|
||||
for (int j = 0; j < vecWidth; ++j)
|
||||
for (int j = 0; j < vecWidth; ++j) {
|
||||
vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]);
|
||||
}
|
||||
// row + trans and col + no-trans are equivalent
|
||||
bool isActualTrans =
|
||||
(needTrans && kOrder == 1) || (!needTrans && kOrder == 0);
|
||||
@@ -398,13 +405,14 @@ MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
ctx(rewriter.getContext()) {
|
||||
contiguousMatShape = matShape[order[0]];
|
||||
stridedMatShape = matShape[order[1]];
|
||||
|
||||
stridedSmemOffset = smemStrides[order[1]];
|
||||
vecWidth = 4 / elemBytes;
|
||||
|
||||
// rule: k must be the fast-changing axis.
|
||||
needTrans = kOrder != order[0];
|
||||
canUseLdmatrix = elemBytes == 2 || (!needTrans);
|
||||
canUseLdmatrix = canUseLdmatrix && (kWidth == 4 / elemBytes);
|
||||
canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth);
|
||||
// canUseLdmatrix = false;
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
// Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed,
|
||||
@@ -414,7 +422,7 @@ MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
} else {
|
||||
numPtrs = tileShape[order[0]] / (needTrans ? warpsPerTile : 1) /
|
||||
matShape[order[0]];
|
||||
numPtrs *= 4 / elemBytes;
|
||||
numPtrs *= kWidth;
|
||||
}
|
||||
numPtrs = std::max<int>(numPtrs, 2);
|
||||
|
||||
@@ -488,9 +496,21 @@ std::function<void(int, int)> getLoadMatrixFn(
|
||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
const int perPhase = sharedLayout.getPerPhase();
|
||||
const int maxPhase = sharedLayout.getMaxPhase();
|
||||
const int vecPhase = sharedLayout.getVec();
|
||||
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
if (tensor.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::Float8E4M3B11FNUZType>()) {
|
||||
bool noTrans = (isA ^ order[0] == 0);
|
||||
assert(noTrans && "float8e4b15 must have row-col layout");
|
||||
}
|
||||
|
||||
if (kWidth != (4 / elemBytes))
|
||||
assert(vecPhase == 1 || vecPhase == 4 * kWidth);
|
||||
|
||||
// (a, b) is the coordinate.
|
||||
auto load = [=, &rewriter, &vals](int a, int b) {
|
||||
MMA16816SmemLoader loader(
|
||||
|
||||
@@ -342,7 +342,6 @@ static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
|
||||
// ret.push_back(values[i + 14]);
|
||||
// ret.push_back(values[i + 15]);
|
||||
// }
|
||||
return values;
|
||||
}
|
||||
llvm_unreachable("unimplemented code path");
|
||||
}
|
||||
|
||||
@@ -782,7 +782,9 @@ struct InsertSliceAsyncOpConversion
|
||||
// start of the vector and the other pointer moving to the next vector.
|
||||
unsigned inVec = getContiguity(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned minVec = inVec;
|
||||
if (outVec > 1)
|
||||
minVec = std::min(outVec, inVec);
|
||||
unsigned numElems = getTotalElemsPerThread(srcTy);
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
|
||||
@@ -566,7 +566,9 @@ private:
|
||||
inVec =
|
||||
std::min<unsigned>(axisInfoAnalysis.getMaskAlignment(mask), inVec);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned minVec = inVec;
|
||||
if (outVec > 1)
|
||||
minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
|
||||
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
|
||||
@@ -577,8 +579,9 @@ private:
|
||||
// capability does not support async copy, then we do decompose
|
||||
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
|
||||
computeCapability)
|
||||
.contains(byteWidth))
|
||||
.contains(byteWidth)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// load
|
||||
auto tmpTy =
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
@@ -76,6 +77,29 @@ class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
|
||||
|
||||
static bool bwdFilter(Operation *op) {
|
||||
return op->getNumOperands() == 1 &&
|
||||
(isa<triton::FpToFpOp, triton::BitcastOp,
|
||||
triton::gpu::ConvertLayoutOp>(op) ||
|
||||
op->getDialect()->getTypeID() ==
|
||||
mlir::TypeID::get<arith::ArithDialect>());
|
||||
}
|
||||
|
||||
// finds the first different value bitwidth in the chain of
|
||||
// shape-preserving unary ops that x depends on
|
||||
static int computeOrigBitWidth(Value x) {
|
||||
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
|
||||
int origBitWidth = finalBitWidth;
|
||||
SetVector<Operation *> slice;
|
||||
mlir::getBackwardSlice(x, &slice, bwdFilter);
|
||||
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
|
||||
if (firstOp)
|
||||
if (Value arg = firstOp->getOperand(0))
|
||||
if (RankedTensorType argTy = arg.getType().dyn_cast<RankedTensorType>())
|
||||
origBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
|
||||
return origBitWidth;
|
||||
}
|
||||
|
||||
public:
|
||||
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||
@@ -87,6 +111,7 @@ public:
|
||||
if (computeCapability < 70)
|
||||
return failure();
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
auto ctx = op->getContext();
|
||||
// TODO: Check data-types and SM compatibility
|
||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (!oldRetType.getEncoding() ||
|
||||
@@ -151,36 +176,28 @@ public:
|
||||
}
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
|
||||
|
||||
// convert accumulator
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
oldAcc.getLoc(), newRetType, oldAcc);
|
||||
auto oldAOrder = oldAType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
// convert operands
|
||||
int minBitwidth = std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
|
||||
Type minType = IntegerType::get(ctx, minBitwidth);
|
||||
// convert A operand
|
||||
auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldAType.getContext(), 0, newRetType.getEncoding(),
|
||||
oldAType.getElementType());
|
||||
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(),
|
||||
oldBType.getElementType());
|
||||
|
||||
minBitwidth > 0 ? minType : oldAType.getElementType());
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(), newAEncoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
// convert B operand
|
||||
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(),
|
||||
minBitwidth > 0 ? minType : oldBType.getElementType());
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(), newBEncoding);
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
// convert dot instruction
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -73,260 +74,77 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
|
||||
// convert(layout_preserving_op(x), dot_operand)
|
||||
// -> layout_preserving_op(convert(x, dot_operand))
|
||||
class MoveOpAfterLayoutConversion : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
MoveOpAfterLayoutConversion(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 1, context) {}
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
static mlir::LogicalResult
|
||||
isBlockedToDotOperand(mlir::Operation *op,
|
||||
triton::gpu::DotOperandEncodingAttr &retEncoding,
|
||||
triton::gpu::BlockedEncodingAttr &srcEncoding) {
|
||||
auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(op);
|
||||
if (!cvt)
|
||||
return failure();
|
||||
auto srcTy = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto retTy = cvt.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
retEncoding =
|
||||
retTy.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
srcEncoding =
|
||||
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
if (!retTy)
|
||||
return failure();
|
||||
if (!retEncoding)
|
||||
return failure();
|
||||
auto retEncodingParent =
|
||||
retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!retEncodingParent || retEncodingParent.isVolta())
|
||||
return failure();
|
||||
if (!srcEncoding)
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool isTrans(const triton::gpu::DotOperandEncodingAttr &retEncoding,
|
||||
const triton::gpu::BlockedEncodingAttr &srcEncoding) {
|
||||
int kOrder = retEncoding.getOpIdx() ^ 1;
|
||||
return kOrder != srcEncoding.getOrder()[0];
|
||||
}
|
||||
|
||||
static bool isDotNT(triton::DotOp dotOp) {
|
||||
triton::gpu::DotOperandEncodingAttr aRetEncoding;
|
||||
triton::gpu::DotOperandEncodingAttr bRetEncoding;
|
||||
triton::gpu::BlockedEncodingAttr aSrcEncoding;
|
||||
triton::gpu::BlockedEncodingAttr bSrcEncoding;
|
||||
if (isBlockedToDotOperand(dotOp.getOperand(0).getDefiningOp(), aRetEncoding,
|
||||
aSrcEncoding)
|
||||
.failed())
|
||||
return false;
|
||||
if (isBlockedToDotOperand(dotOp.getOperand(1).getDefiningOp(), bRetEncoding,
|
||||
bSrcEncoding)
|
||||
.failed())
|
||||
return false;
|
||||
if (!aRetEncoding || !bRetEncoding || !aSrcEncoding || !bSrcEncoding)
|
||||
return false;
|
||||
return !isTrans(aRetEncoding, aSrcEncoding) &&
|
||||
!isTrans(bRetEncoding, bSrcEncoding);
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
// only supports dot NT
|
||||
if (!isDotNT(dotOp))
|
||||
return failure();
|
||||
bool changed = false;
|
||||
for (Value operand : {dotOp.getOperand(0), dotOp.getOperand(1)}) {
|
||||
auto cvt = operand.getDefiningOp<triton::gpu::ConvertLayoutOp>();
|
||||
triton::gpu::DotOperandEncodingAttr retEncoding;
|
||||
triton::gpu::BlockedEncodingAttr srcEncoding;
|
||||
bool failed =
|
||||
isBlockedToDotOperand(cvt, retEncoding, srcEncoding).failed();
|
||||
assert(!failed);
|
||||
|
||||
// don't move things around when cvt operand is a block arg
|
||||
Operation *argOp = cvt.getOperand().getDefiningOp();
|
||||
if (!argOp)
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
// conversion should be dependent on a load
|
||||
// and all operations between the load and the conversion
|
||||
// should be layout preserving
|
||||
SetVector<Operation *> slice;
|
||||
getBackwardSlice(op, &slice);
|
||||
int loadIdx = -1;
|
||||
bool checkOp = false;
|
||||
for (int i = 0; i < slice.size(); i++) {
|
||||
Operation *currOp = *(slice.begin() + i);
|
||||
if (currOp->getParentRegion() != op->getParentRegion())
|
||||
continue;
|
||||
SetVector<Operation *> processed;
|
||||
SetVector<Attribute> layout;
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
int numCvts = simulateBackwardRematerialization(cvt, processed, layout,
|
||||
toConvert, retEncoding);
|
||||
if (numCvts > 1 || toConvert.size() == 1)
|
||||
continue;
|
||||
bool replaceOperand = true;
|
||||
for (Operation *op : processed) {
|
||||
if (op->getNumOperands() != 1)
|
||||
continue;
|
||||
auto srcTy = op->getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto dstTy = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
// we don't want to push conversions backward if there is a downcast
|
||||
// since it would result in more shared memory traffic
|
||||
if (srcTy.getElementType().getIntOrFloatBitWidth() >
|
||||
dstTy.getElementType().getIntOrFloatBitWidth()) {
|
||||
replaceOperand = false;
|
||||
break;
|
||||
}
|
||||
// we only push back when the first op in the chain has a load operand
|
||||
if ((op == processed.back()) &&
|
||||
!isa<triton::LoadOp>(op->getOperand(0).getDefiningOp())) {
|
||||
replaceOperand = false;
|
||||
break;
|
||||
}
|
||||
// we don't want to use ldmatrix for 8-bit data that requires trans
|
||||
// since Nvidia GPUs can't do it efficiently
|
||||
int kOrder = retEncoding.getOpIdx() ^ 1;
|
||||
bool isTrans = kOrder != srcEncoding.getOrder()[0];
|
||||
bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8;
|
||||
if (isTrans && isInt8) {
|
||||
replaceOperand = false;
|
||||
break;
|
||||
}
|
||||
if (isa<triton::LoadOp>(currOp))
|
||||
checkOp = true;
|
||||
else if (checkOp) {
|
||||
if (!isa<triton::FpToFpOp, triton::BitcastOp>(currOp) &&
|
||||
currOp->getDialect()->getTypeID() !=
|
||||
mlir::TypeID::get<arith::ArithDialect>())
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!replaceOperand)
|
||||
continue;
|
||||
IRMapping mapping;
|
||||
rematerializeConversionChain(toConvert, rewriter, processed, mapping);
|
||||
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
|
||||
changed = true;
|
||||
}
|
||||
return mlir::success(changed);
|
||||
if (!checkOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto cvtTy = cvt.getType().cast<RankedTensorType>();
|
||||
auto cvtArgOp = cvt.getSrc().getDefiningOp();
|
||||
if (!cvtArgOp || cvtArgOp->getNumOperands() == 0)
|
||||
return mlir::failure();
|
||||
// only consider custom conversions or arith ops
|
||||
if (!isa<triton::FpToFpOp, triton::BitcastOp>(cvtArgOp) &&
|
||||
cvtArgOp->getDialect()->getTypeID() !=
|
||||
mlir::TypeID::get<arith::ArithDialect>())
|
||||
return mlir::failure();
|
||||
// only considers conversions to dot operand
|
||||
if (!cvtTy.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return mlir::failure();
|
||||
auto argTy = cvtArgOp->getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto retTy = cvtArgOp->getResult(0).getType().cast<RankedTensorType>();
|
||||
if (!argTy || !retTy)
|
||||
return mlir::failure();
|
||||
Type newRetTy = RankedTensorType::get(
|
||||
retTy.getShape(), retTy.getElementType(), cvtTy.getEncoding());
|
||||
Type newCvtTy = RankedTensorType::get(
|
||||
retTy.getShape(), argTy.getElementType(), cvtTy.getEncoding());
|
||||
int numArgs = cvtArgOp->getNumOperands();
|
||||
SmallVector<triton::gpu::ConvertLayoutOp> newCvts(numArgs);
|
||||
for (int i = 0; i < numArgs; i++)
|
||||
newCvts[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvt.getLoc(), newCvtTy, cvtArgOp->getOperand(i));
|
||||
auto newRet = rewriter.clone(*cvtArgOp);
|
||||
for (int i = 0; i < numArgs; i++)
|
||||
newRet->setOperand(i, newCvts[i]);
|
||||
newRet->getResult(0).setType(newRetTy);
|
||||
rewriter.replaceOp(op, newRet->getResults());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static bool isConvertToDotEncoding(Operation *op) {
|
||||
auto convertLayout = llvm::dyn_cast<ConvertLayoutOp>(op);
|
||||
if (!convertLayout)
|
||||
return false;
|
||||
auto tensorType =
|
||||
convertLayout.getResult().getType().cast<RankedTensorType>();
|
||||
return tensorType.getEncoding().isa<DotOperandEncodingAttr>();
|
||||
}
|
||||
|
||||
static ConvertLayoutOp updateConvert(OpBuilder &builder, ConvertLayoutOp cvt,
|
||||
IRMapping &mapping, Type smallestType) {
|
||||
auto cvtDstTy = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
auto cvtDstEnc = cvtDstTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
Value operand = cvt.getOperand();
|
||||
if (mapping.contains(operand))
|
||||
operand = mapping.lookup(operand);
|
||||
auto newDstTy = RankedTensorType::get(
|
||||
cvtDstTy.getShape(), cvtDstTy.getElementType(),
|
||||
DotOperandEncodingAttr::get(cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(),
|
||||
cvtDstEnc.getParent(), smallestType));
|
||||
auto newCvt =
|
||||
builder.create<ConvertLayoutOp>(cvt.getLoc(), newDstTy, operand);
|
||||
mapping.map(cvt.getResult(), newCvt.getResult());
|
||||
return newCvt;
|
||||
}
|
||||
|
||||
// Update kWidth based on the smallestType found in the given convert ops and
|
||||
// propagate the type change.
|
||||
static void
|
||||
updateDotEncodingLayout(SmallVector<ConvertLayoutOp> &convertsToDotEncoding,
|
||||
Type smallestType) {
|
||||
IRMapping mapping;
|
||||
OpBuilder builder(smallestType.getContext());
|
||||
SetVector<Operation *> slices(convertsToDotEncoding.begin(),
|
||||
convertsToDotEncoding.end());
|
||||
// Collect all the operations where the type needs to be propagated.
|
||||
for (auto cvt : convertsToDotEncoding) {
|
||||
auto forwardFilter = [&](Operation *op) {
|
||||
if (op == cvt.getOperation())
|
||||
return true;
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
|
||||
if (tensorType &&
|
||||
tensorType.getEncoding().isa<DotOperandEncodingAttr>())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
auto backwardFilter = [&](Operation *op) {
|
||||
for (Value results : op->getResults()) {
|
||||
auto tensorType = results.getType().dyn_cast<RankedTensorType>();
|
||||
if (tensorType &&
|
||||
tensorType.getEncoding().isa<DotOperandEncodingAttr>())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
SetVector<Operation *> opSlice =
|
||||
getSlice(cvt.getOperation(), {backwardFilter}, {forwardFilter});
|
||||
slices.insert(opSlice.begin(), opSlice.end());
|
||||
}
|
||||
// Apply the type change by walking ops in topological order.
|
||||
slices = mlir::topologicalSort(slices);
|
||||
for (Operation *op : slices) {
|
||||
builder.setInsertionPoint(op);
|
||||
if (isConvertToDotEncoding(op)) {
|
||||
auto cvt = cast<ConvertLayoutOp>(op);
|
||||
ConvertLayoutOp newCvt =
|
||||
updateConvert(builder, cvt, mapping, smallestType);
|
||||
continue;
|
||||
}
|
||||
auto *newOp = cloneWithInferType(builder, op, mapping);
|
||||
for (auto [result, newResult] :
|
||||
llvm::zip(op->getResults(), newOp->getResults())) {
|
||||
result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
|
||||
return slices.count(operand.getOwner()) == 0;
|
||||
});
|
||||
}
|
||||
}
|
||||
for (Operation *op : llvm::reverse(slices))
|
||||
op->erase();
|
||||
}
|
||||
|
||||
// Change the layout of dotOperand layout to use the kWidth from the smallest
|
||||
// loaded type. This allows better code generation for mixed-mode matmul.
|
||||
static void optimizeKWidth(triton::FuncOp func) {
|
||||
SmallVector<ConvertLayoutOp> convertsToDotEncoding;
|
||||
Type smallestType;
|
||||
func->walk([&](triton::LoadOp loadOp) {
|
||||
if (!loadOp.getResult().hasOneUse())
|
||||
return;
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
|
||||
// Advance to the first conversion as long as the use resides in shared
|
||||
// memory and it has a single use itself
|
||||
while (use) {
|
||||
if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse())
|
||||
break;
|
||||
auto tensorType =
|
||||
use->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorType || !tensorType.getEncoding().isa<SharedEncodingAttr>())
|
||||
break;
|
||||
use = *use->getResult(0).getUsers().begin();
|
||||
}
|
||||
|
||||
auto convertLayout = llvm::dyn_cast<ConvertLayoutOp>(use);
|
||||
if (!convertLayout)
|
||||
return;
|
||||
auto tensorType =
|
||||
convertLayout.getResult().getType().cast<RankedTensorType>();
|
||||
if (!tensorType.getEncoding().isa<DotOperandEncodingAttr>())
|
||||
return;
|
||||
convertsToDotEncoding.push_back(convertLayout);
|
||||
|
||||
// Update the smallest type.
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
Type eltTy = ty.getElementType();
|
||||
if (!smallestType ||
|
||||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
|
||||
smallestType = eltTy;
|
||||
});
|
||||
if (!smallestType)
|
||||
return;
|
||||
updateDotEncodingLayout(convertsToDotEncoding, smallestType);
|
||||
}
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -351,10 +169,6 @@ public:
|
||||
signalPassFailure();
|
||||
if (fixupLoops(m).failed())
|
||||
signalPassFailure();
|
||||
|
||||
// Change the layout of dotOperand layout to use the kWidth from the
|
||||
// smallest loaded type.
|
||||
m->walk([](triton::FuncOp func) { optimizeKWidth(func); });
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -493,9 +493,7 @@ void LoopPipeliner::createBufferTypes() {
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
|
||||
? 32 / dotOpEnc.getMMAv2kWidth()
|
||||
: ty.getElementType().getIntOrFloatBitWidth();
|
||||
unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth();
|
||||
auto sharedEnc =
|
||||
ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
ttg::getOrder(ty.getEncoding()), bitWidth);
|
||||
|
||||
@@ -89,13 +89,21 @@ def f8_to_f16(x, dtype):
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
|
||||
(128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE),
|
||||
] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
|
||||
] for ADTYPE, BDTYPE in [("float8e4", "float8e5"),
|
||||
("float8e4", "float16"),
|
||||
("float16", "float8e5"),
|
||||
("float16", "float32"),
|
||||
("float32", "float16"),
|
||||
("bfloat16", "float32"),
|
||||
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
*[
|
||||
# float8e4b15 only supports row-col layout
|
||||
[
|
||||
(128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE),
|
||||
] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
|
||||
("float8e4b15", "float16"),
|
||||
("float16", "float8e4b15")]
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -132,7 +140,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
if t:
|
||||
return init_input(m, n, False, dtype, is_float8).t()
|
||||
if is_float8:
|
||||
return torch.randint(20, 60, (n, m), device="cuda", dtype=torch.int8)
|
||||
return torch.randint(20, 50, (n, m), device="cuda", dtype=torch.int8)
|
||||
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
|
||||
return .1 * torch.randn((n, m), device="cuda", dtype=dtype)
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -tritongpu-remove-layout-conversions -canonicalize | FileCheck %s
|
||||
|
||||
#Cv2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#Av2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}>
|
||||
#Bv2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}>
|
||||
#Av2k1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}>
|
||||
#Bv2k1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}>
|
||||
#Av2k2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}>
|
||||
#Bv2k2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}>
|
||||
#Av2k4 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}>
|
||||
#Bv2k4 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}>
|
||||
#Cv1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}>
|
||||
#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}>
|
||||
#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}>
|
||||
@@ -13,7 +17,7 @@
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK: tt.func @push_elementwise1
|
||||
// CHECK: tt.func @push_elementwise
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
|
||||
@@ -22,7 +26,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]]
|
||||
// CHECK-SAME: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma>
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise1(
|
||||
tt.func @push_elementwise(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
@@ -30,103 +34,13 @@ tt.func @push_elementwise1(
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2k4>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
|
||||
// Not modified for row-row
|
||||
// CHECK: tt.func @push_elementwise2
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise2(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLR>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLR>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
|
||||
// Not modified for col-row
|
||||
// CHECK: tt.func @push_elementwise3
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise3(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALC>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLR>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALC> -> tensor<16x16xf8E5M2, #ALC>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALC> -> tensor<16x16xf16, #ALC>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALC>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLR>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
// Not modified for col-col
|
||||
// CHECK: tt.func @push_elementwise4
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise4(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALC>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALC> -> tensor<16x16xf8E5M2, #ALC>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALC> -> tensor<16x16xf16, #ALC>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALC>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
|
||||
// Not modified for Volta
|
||||
// CHECK: tt.func @push_elementwise5
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
|
||||
tt.func @push_elementwise5(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv1>) -> tensor<16x16xf32, #Cv1>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av1>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv1>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv1>
|
||||
}
|
||||
|
||||
// CHECK: tt.func @succeeds_if_arg_is_not_convert_layout
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
|
||||
@@ -139,12 +53,12 @@ tt.func @succeeds_if_arg_is_not_convert_layout(
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
|
||||
%dotai8 = triton_gpu.convert_layout %ai8 : (tensor<16x16xi8, #ALR>) -> tensor<16x16xi8, #Av2>
|
||||
%dotai8 = triton_gpu.convert_layout %ai8 : (tensor<16x16xi8, #ALR>) -> tensor<16x16xi8, #Av2k4>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
|
||||
%dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2> -> tensor<16x16xf8E5M2, #Av2>
|
||||
%dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2> -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
%dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4>
|
||||
%dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
@@ -163,10 +77,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK: tt.func @push_convert_both_operands
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #[[BA]]>
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : (tensor<16x16xf16, #[[BA]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #[[BB]]>
|
||||
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : (tensor<16x16xf16, #[[BA]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
|
||||
tt.func @push_convert_both_operands(
|
||||
@@ -177,9 +91,9 @@ tt.func @push_convert_both_operands(
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #blockedB>
|
||||
%ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA>
|
||||
%be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB>
|
||||
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
|
||||
%bl = triton_gpu.convert_layout %be : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
|
||||
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma>
|
||||
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
|
||||
%bl = triton_gpu.convert_layout %be : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
|
||||
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
|
||||
tt.return %r : tensor<16x16xf32, #mma>
|
||||
}
|
||||
|
||||
@@ -199,10 +113,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: tt.func @update_kwidth_slice
|
||||
// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #[[BA]]>
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : (tensor<16x16xf16, #[[BA]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #[[BB]]>
|
||||
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : (tensor<16x16xf16, #[[BA]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
|
||||
// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
|
||||
@@ -216,9 +130,9 @@ tt.func @update_kwidth_slice(
|
||||
%ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA>
|
||||
%be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB>
|
||||
%add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB>
|
||||
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
|
||||
%bl = triton_gpu.convert_layout %add : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
|
||||
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma>
|
||||
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
|
||||
%bl = triton_gpu.convert_layout %add : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
|
||||
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
|
||||
tt.return %r : tensor<16x16xf32, #mma>
|
||||
}
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
|
||||
ctx.loadDialect<triton::gpu::TritonGPUDialect>();
|
||||
// create encoding
|
||||
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1});
|
||||
auto encoding =
|
||||
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent, 0);
|
||||
auto encoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
&ctx, params.opIdx, parent, 32 / params.typeWidth);
|
||||
|
||||
// create element type
|
||||
Type eltType = IntegerType::get(&ctx, params.typeWidth);
|
||||
|
||||
Reference in New Issue
Block a user