mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Added kWidth attribute to DotOperandEncoding (#1584)
This is a prerequisite for efficient mixed-precision matmul
This commit is contained in:
@@ -501,10 +501,22 @@ section 9.7.13.4.1 for more details.
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$opIdx,
|
||||
"Attribute":$parent
|
||||
"Attribute":$parent,
|
||||
"unsigned":$MMAv2kWidth
|
||||
);
|
||||
|
||||
let builders = [
|
||||
// Specially for MMAV1(Volta)
|
||||
AttrBuilder<(ins "unsigned":$opIdx,
|
||||
"Attribute":$parent,
|
||||
"Type":$eltTy), [{
|
||||
MmaEncodingAttr parentAttr = parent.dyn_cast<MmaEncodingAttr>();
|
||||
if (!parentAttr || !parentAttr.isAmpere())
|
||||
return $_get(context, opIdx, parent, 0);
|
||||
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
|
||||
unsigned MMAv2kWidth = 32 / bitwidth;
|
||||
return $_get(context, opIdx, parent, MMAv2kWidth);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
@@ -325,7 +325,8 @@ private:
|
||||
|
||||
if (needTrans) {
|
||||
// do transpose
|
||||
auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
|
||||
auto aEncoding =
|
||||
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
|
||||
int numM = aEncoding.getMMAv1NumOuter(shape);
|
||||
int numN = accumSizePerThread / numM;
|
||||
|
||||
|
||||
@@ -358,11 +358,11 @@ SmallVector<CoordTy> getMNCoords(Value thread,
|
||||
Value _fpw1 = i32_val(fpw[1]);
|
||||
|
||||
// A info
|
||||
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
|
||||
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
|
||||
auto aRep = aEncoding.getMMAv1Rep();
|
||||
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
|
||||
// B info
|
||||
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
|
||||
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
|
||||
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
|
||||
auto bRep = bEncoding.getMMAv1Rep();
|
||||
|
||||
|
||||
@@ -714,11 +714,11 @@ private:
|
||||
Value _fpw1 = i32_val(fpw[1]);
|
||||
|
||||
// A info
|
||||
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
|
||||
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
|
||||
auto aRep = aEncoding.getMMAv1Rep();
|
||||
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
|
||||
// B info
|
||||
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
|
||||
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
|
||||
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
|
||||
auto bRep = bEncoding.getMMAv1Rep();
|
||||
|
||||
@@ -775,12 +775,12 @@ private:
|
||||
// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
|
||||
// A info
|
||||
auto aEncoding =
|
||||
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
|
||||
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout, 0);
|
||||
auto aRep = aEncoding.getMMAv1Rep();
|
||||
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
|
||||
// B info
|
||||
auto bEncoding =
|
||||
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
|
||||
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout, 0);
|
||||
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
|
||||
auto bRep = bEncoding.getMMAv1Rep();
|
||||
|
||||
|
||||
@@ -268,6 +268,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
// a & b must be of smem layout
|
||||
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
|
||||
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
|
||||
Type aEltType = aType.getElementType();
|
||||
Type bEltType = bType.getElementType();
|
||||
Attribute aEncoding = aType.getEncoding();
|
||||
Attribute bEncoding = bType.getEncoding();
|
||||
if (!aEncoding || !bEncoding)
|
||||
@@ -276,17 +278,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
Value b = adaptor.getB();
|
||||
Value c = adaptor.getC();
|
||||
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(),
|
||||
aType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
getContext(), 0, dEncoding, aEltType);
|
||||
auto dstType =
|
||||
RankedTensorType::get(aType.getShape(), aEltType, encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(),
|
||||
bType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
getContext(), 1, dEncoding, bEltType);
|
||||
auto dstType =
|
||||
RankedTensorType::get(bType.getShape(), bEltType, encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
|
||||
|
||||
@@ -774,14 +774,27 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
|
||||
Attribute parent = attrs.get("parent");
|
||||
auto mmaParent = parent.dyn_cast<MmaEncodingAttr>();
|
||||
unsigned kWidth = 0;
|
||||
Attribute _kWidth = attrs.get("kWidth");
|
||||
if (_kWidth) {
|
||||
if (!mmaParent || mmaParent.isVolta()) {
|
||||
auto loc = parser.getNameLoc();
|
||||
parser.emitError(loc, "kWidth only supported for MMAv2+ parent");
|
||||
return Attribute();
|
||||
}
|
||||
kWidth = _kWidth.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
|
||||
parent);
|
||||
parent, kWidth);
|
||||
}
|
||||
|
||||
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
|
||||
printer << "<{"
|
||||
<< "opIdx = " << getOpIdx() << ", "
|
||||
<< "parent = " << getParent();
|
||||
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
|
||||
if (mmaParent && mmaParent.isAmpere())
|
||||
printer << ", kWidth = " << getMMAv2kWidth();
|
||||
printer << "}>";
|
||||
}
|
||||
|
||||
|
||||
@@ -170,14 +170,17 @@ public:
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldAType.getContext(), 0, newRetType.getEncoding(),
|
||||
oldAType.getElementType());
|
||||
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(),
|
||||
oldBType.getElementType());
|
||||
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
|
||||
newRetType.getEncoding()));
|
||||
oldAType.getShape(), oldAType.getElementType(), newAEncoding);
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
|
||||
newRetType.getEncoding()));
|
||||
oldBType.getShape(), oldBType.getElementType(), newBEncoding);
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
@@ -42,6 +43,9 @@ class LoopPipeliner {
|
||||
|
||||
/// Loads to be pipelined
|
||||
SetVector<Value> loads;
|
||||
/// Smallest data-type for each load (used to optimize swizzle and
|
||||
/// create DotOpEncoding layout)
|
||||
DenseMap<Value, Type> loadsSmallestType;
|
||||
/// The value that each load will be mapped to (after layout conversion)
|
||||
DenseMap<Value, Value> loadsMapping;
|
||||
/// load => buffer
|
||||
@@ -256,33 +260,62 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
use = *use->getResult(0).getUsers().begin();
|
||||
}
|
||||
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult()
|
||||
.getType()
|
||||
.dyn_cast<RankedTensorType>()) {
|
||||
if (auto dotOpEnc = tensorType.getEncoding()
|
||||
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
|
||||
isCandidate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
auto sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
|
||||
loadsBufferType[loadOp] = RankedTensorType::get(
|
||||
bufferShape, ty.getElementType(), sharedEnc);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else
|
||||
auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use);
|
||||
if (!convertLayout)
|
||||
continue;
|
||||
auto tensorType =
|
||||
convertLayout.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
auto dotOpEnc =
|
||||
tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
|
||||
if (!dotOpEnc)
|
||||
continue;
|
||||
isCandidate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
|
||||
else
|
||||
isCandidate = false;
|
||||
|
||||
if (isCandidate)
|
||||
loads.insert(loadOp);
|
||||
}
|
||||
|
||||
// we need to find the smallest common dtype
|
||||
// since this determines the layout of `mma.sync` operands
|
||||
// in mixed-precision mode
|
||||
Type smallestType;
|
||||
for (auto loadCvt : loadsMapping) {
|
||||
auto loadOp = loadCvt.first;
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
Type eltTy = ty.getElementType();
|
||||
if (!smallestType ||
|
||||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
|
||||
smallestType = eltTy;
|
||||
}
|
||||
|
||||
for (auto loadCvt : loadsMapping)
|
||||
loadsSmallestType[loadCvt.first] = smallestType;
|
||||
|
||||
for (auto loadCvt : loadsMapping) {
|
||||
auto loadOp = loadCvt.first;
|
||||
Value cvt = loadCvt.second;
|
||||
auto dotOpEnc = cvt.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<ttg::DotOperandEncodingAttr>();
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
auto sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
|
||||
loadsBufferType[loadOp] =
|
||||
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
|
||||
}
|
||||
|
||||
// We have some loads to pipeline
|
||||
if (!loads.empty()) {
|
||||
// Update depArgs & depOps
|
||||
@@ -551,8 +584,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
}
|
||||
// we replace the use of the new load with a convert layout
|
||||
size_t i = std::distance(loads.begin(), it);
|
||||
auto cvtDstTy = op.getResult(0).getType().cast<RankedTensorType>();
|
||||
auto cvtDstEnc = cvtDstTy.getEncoding().cast<ttg::DotOperandEncodingAttr>();
|
||||
auto newDstTy = RankedTensorType::get(
|
||||
cvtDstTy.getShape(), cvtDstTy.getElementType(),
|
||||
ttg::DotOperandEncodingAttr::get(
|
||||
cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(),
|
||||
loadsSmallestType[op.getOperand(0)]));
|
||||
auto cvt = builder.create<ttg::ConvertLayoutOp>(
|
||||
op.getLoc(), op.getResult(0).getType(),
|
||||
op.getResult(0).getLoc(), newDstTy,
|
||||
newForOp.getRegionIterArgs()[loadIdx + i]);
|
||||
mapping.map(op.getResult(0), cvt.getResult());
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
|
||||
|
||||
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
|
||||
builder.getContext(), opIdx, dotEncoding);
|
||||
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
|
||||
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
|
||||
newSmem);
|
||||
@@ -156,12 +156,22 @@ LogicalResult Prefetcher::initialize() {
|
||||
};
|
||||
|
||||
for (triton::DotOp dot : dotsInFor) {
|
||||
auto kSize = dot.getA().getType().cast<RankedTensorType>().getShape()[1];
|
||||
auto aType = dot.getA().getType().cast<RankedTensorType>();
|
||||
auto bType = dot.getB().getType().cast<RankedTensorType>();
|
||||
auto aEnc = aType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto bEnc = bType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
int aKWidth = aEnc.getMMAv2kWidth();
|
||||
int bKWidth = bEnc.getMMAv2kWidth();
|
||||
assert(aKWidth == bKWidth);
|
||||
|
||||
auto kSize = aType.getShape()[1];
|
||||
|
||||
// works better with nvidia tensor cores
|
||||
unsigned elementWidth =
|
||||
dot.getA().getType().cast<RankedTensorType>().getElementTypeBitWidth();
|
||||
prefetchWidth = 256 / elementWidth;
|
||||
unsigned elementWidth = aType.getElementTypeBitWidth();
|
||||
if (aKWidth == 0)
|
||||
prefetchWidth = 256 / elementWidth;
|
||||
else
|
||||
prefetchWidth = 8 * aKWidth;
|
||||
|
||||
// Skip prefetching if kSize is less than prefetchWidth
|
||||
if (kSize < prefetchWidth)
|
||||
|
||||
@@ -341,7 +341,10 @@ public:
|
||||
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto dstEncoding =
|
||||
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
|
||||
// XXX: why is this needed?
|
||||
if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>() ||
|
||||
dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return failure();
|
||||
// heuristics for flash attention
|
||||
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
|
||||
return failure();
|
||||
SetVector<Operation *> cvtSlices;
|
||||
|
||||
@@ -206,11 +206,15 @@ int simulateBackwardRematerialization(
|
||||
|
||||
//
|
||||
|
||||
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping) {
|
||||
Operation *newOp = rewriter.clone(*op, mapping);
|
||||
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
|
||||
if (newOp->getNumResults() == 0)
|
||||
return newOp;
|
||||
auto origType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
auto argType = newOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!origType || !argType)
|
||||
return newOp;
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(), argType.getEncoding());
|
||||
newOp->getResult(0).setType(newType);
|
||||
|
||||
@@ -21,7 +21,7 @@ int simulateBackwardRematerialization(
|
||||
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
|
||||
Attribute targetEncoding);
|
||||
|
||||
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping);
|
||||
|
||||
void rematerializeConversionChain(
|
||||
|
||||
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
Binary file not shown.
@@ -6,8 +6,8 @@
|
||||
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
|
||||
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
// There shouldn't be any aliasing with the dot op encoding.
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
|
||||
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
|
||||
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
|
||||
@@ -755,8 +755,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
|
||||
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_dot
|
||||
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
||||
@@ -897,8 +897,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
||||
@@ -969,8 +969,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: matmul_tf32dot
|
||||
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}>
|
||||
#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}>
|
||||
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
|
||||
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
|
||||
|
||||
// CHECK: tt.func @matmul_loop
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
|
||||
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
|
||||
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>
|
||||
|
||||
|
||||
// CHECK: tt.func @matmul_loop
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
|
||||
// create encoding
|
||||
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1});
|
||||
auto encoding =
|
||||
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);
|
||||
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent, 0);
|
||||
|
||||
// create element type
|
||||
Type eltType = IntegerType::get(&ctx, params.typeWidth);
|
||||
|
||||
Reference in New Issue
Block a user