[OPTIMIZER] Added kWidth attribute to DotOperandEncoding (#1584)

This is a prerequisite for efficient mixed-precision matmul
This commit is contained in:
Philippe Tillet
2023-04-26 23:03:18 -07:00
committed by GitHub
parent 167206924c
commit 8f47bdcc92
20 changed files with 162 additions and 74 deletions

View File

@@ -501,10 +501,22 @@ section 9.7.13.4.1 for more details.
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"unsigned":$MMAv2kWidth
);
let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
MmaEncodingAttr parentAttr = parent.dyn_cast<MmaEncodingAttr>();
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
}]>
];
let hasCustomAssemblyFormat = 1;

View File

@@ -325,7 +325,8 @@ private:
if (needTrans) {
// do transpose
auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numN = accumSizePerThread / numM;

View File

@@ -358,11 +358,11 @@ SmallVector<CoordTy> getMNCoords(Value thread,
Value _fpw1 = i32_val(fpw[1]);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

View File

@@ -714,11 +714,11 @@ private:
Value _fpw1 = i32_val(fpw[1]);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();
@@ -775,12 +775,12 @@ private:
// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
// A info
auto aEncoding =
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding =
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

View File

@@ -268,6 +268,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
// a & b must be of smem layout
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
Type aEltType = aType.getElementType();
Type bEltType = bType.getElementType();
Attribute aEncoding = aType.getEncoding();
Attribute bEncoding = bType.getEncoding();
if (!aEncoding || !bEncoding)
@@ -276,17 +278,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
Value b = adaptor.getB();
Value c = adaptor.getC();
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 0, dEncoding, aEltType);
auto dstType =
RankedTensorType::get(aType.getShape(), aEltType, encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 1, dEncoding, bEltType);
auto dstType =
RankedTensorType::get(bType.getShape(), bEltType, encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

View File

@@ -774,14 +774,27 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
auto mmaParent = parent.dyn_cast<MmaEncodingAttr>();
unsigned kWidth = 0;
Attribute _kWidth = attrs.get("kWidth");
if (_kWidth) {
if (!mmaParent || mmaParent.isVolta()) {
auto loc = parser.getNameLoc();
parser.emitError(loc, "kWidth only supported for MMAv2+ parent");
return Attribute();
}
kWidth = _kWidth.cast<IntegerAttr>().getInt();
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, kWidth);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
if (mmaParent && mmaParent.isAmpere())
printer << ", kWidth = " << getMMAv2kWidth();
printer << "}>";
}

View File

@@ -170,14 +170,17 @@ public:
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(),
oldAType.getElementType());
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(),
oldBType.getElementType());
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
oldAType.getShape(), oldAType.getElementType(), newAEncoding);
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
oldBType.getShape(), oldBType.getElementType(), newBEncoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);

View File

@@ -1,3 +1,4 @@
#include "Utility.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
@@ -42,6 +43,9 @@ class LoopPipeliner {
/// Loads to be pipelined
SetVector<Value> loads;
/// Smallest data-type for each load (used to optimize swizzle and
/// create DotOpEncoding layout)
DenseMap<Value, Type> loadsSmallestType;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
@@ -256,33 +260,62 @@ LogicalResult LoopPipeliner::initialize() {
use = *use->getResult(0).getUsers().begin();
}
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}
} else
auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use);
if (!convertLayout)
continue;
auto tensorType =
convertLayout.getResult().getType().dyn_cast<RankedTensorType>();
if (!tensorType)
continue;
auto dotOpEnc =
tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
if (!dotOpEnc)
continue;
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
}
else
isCandidate = false;
if (isCandidate)
loads.insert(loadOp);
}
// we need to find the smallest common dtype
// since this determines the layout of `mma.sync` operands
// in mixed-precision mode
Type smallestType;
for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
auto ty = loadOp.getType().cast<RankedTensorType>();
Type eltTy = ty.getElementType();
if (!smallestType ||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
smallestType = eltTy;
}
for (auto loadCvt : loadsMapping)
loadsSmallestType[loadCvt.first] = smallestType;
for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
Value cvt = loadCvt.second;
auto dotOpEnc = cvt.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<ttg::DotOperandEncodingAttr>();
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
loadsBufferType[loadOp] =
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
}
// We have some loads to pipeline
if (!loads.empty()) {
// Update depArgs & depOps
@@ -551,8 +584,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
// we replace the use new load use with a convert layout
size_t i = std::distance(loads.begin(), it);
auto cvtDstTy = op.getResult(0).getType().cast<RankedTensorType>();
auto cvtDstEnc = cvtDstTy.getEncoding().cast<ttg::DotOperandEncodingAttr>();
auto newDstTy = RankedTensorType::get(
cvtDstTy.getShape(), cvtDstTy.getElementType(),
ttg::DotOperandEncodingAttr::get(
cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(),
loadsSmallestType[op.getOperand(0)]));
auto cvt = builder.create<ttg::ConvertLayoutOp>(
op.getLoc(), op.getResult(0).getType(),
op.getResult(0).getLoc(), newDstTy,
newForOp.getRegionIterArgs()[loadIdx + i]);
mapping.map(op.getResult(0), cvt.getResult());
}

View File

@@ -110,7 +110,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -156,12 +156,22 @@ LogicalResult Prefetcher::initialize() {
};
for (triton::DotOp dot : dotsInFor) {
auto kSize = dot.getA().getType().cast<RankedTensorType>().getShape()[1];
auto aType = dot.getA().getType().cast<RankedTensorType>();
auto bType = dot.getB().getType().cast<RankedTensorType>();
auto aEnc = aType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto bEnc = bType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
int aKWidth = aEnc.getMMAv2kWidth();
int bKWidth = bEnc.getMMAv2kWidth();
assert(aKWidth == bKWidth);
auto kSize = aType.getShape()[1];
// works better with nvidia tensor cores
unsigned elementWidth =
dot.getA().getType().cast<RankedTensorType>().getElementTypeBitWidth();
prefetchWidth = 256 / elementWidth;
unsigned elementWidth = aType.getElementTypeBitWidth();
if (aKWidth == 0)
prefetchWidth = 256 / elementWidth;
else
prefetchWidth = 8 * aKWidth;
// Skip prefetching if kSize is less than prefetchWidth
if (kSize < prefetchWidth)

View File

@@ -341,7 +341,10 @@ public:
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
// XXX: why is this needed?
if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>() ||
dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
return failure();
// heuristics for flash attention
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
return failure();
SetVector<Operation *> cvtSlices;

View File

@@ -206,11 +206,15 @@ int simulateBackwardRematerialization(
//
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
if (newOp->getNumResults() == 0)
return newOp;
auto origType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
if (!origType || !argType)
return newOp;
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(), argType.getEncoding());
newOp->getResult(0).setType(newType);

View File

@@ -21,7 +21,7 @@ int simulateBackwardRematerialization(
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
Attribute targetEncoding);
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);
void rematerializeConversionChain(

Binary file not shown.

View File

@@ -6,8 +6,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
// CHECK-LABEL: matmul_loop
// There shouldn't be any aliasing with the dot op encoding.

View File

@@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {

View File

@@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {

View File

@@ -755,8 +755,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
@@ -897,8 +897,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
@@ -969,8 +969,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},

View File

@@ -8,8 +8,8 @@
#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}>
#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
// CHECK: tt.func @matmul_loop
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32

View File

@@ -7,8 +7,8 @@
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>
// CHECK: tt.func @matmul_loop

View File

@@ -30,7 +30,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
// create encoding
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1});
auto encoding =
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent, 0);
// create element type
Type eltType = IntegerType::get(&ctx, params.typeWidth);