[OPTIMIZER] Fix Shared layout in OptimizeDotOperands pass to generate correct swizzling code (#2180)

fix bug #1937

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Bin Fan
2023-09-13 12:52:09 -07:00
committed by GitHub
parent c61d772eee
commit 38a2ecdccf
3 changed files with 40 additions and 6 deletions

View File

@@ -113,6 +113,16 @@ compared to 1*64 when the hasLeadingOffset is false.
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit), [{
bool needTrans = false; // default value
return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans);
}]>,
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit,
"bool":$needTrans), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
@@ -152,16 +162,23 @@ compared to 1*64 when the hasLeadingOffset is false.
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int m = (needTrans) ? matShape[2] : matShape[0];
int k = (needTrans) ? matShape[0] : matShape[2];
int vec = (order[0] == 1) ? k : m;
int mmaStride = (order[0] == 1) ? m : k;
int maxPhase = mmaStride / perPhase;
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
// --- handle B operand ---
if (opIdx == 1) {
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
// we compute vec and maxPhase m, n and k size of the mma
// instruction. when matmul operands is transposed, we should
// consider that to get m, n and k.
int n = needTrans ? matShape[2] : matShape[1];
int k = needTrans ? matShape[1] : matShape[2];
int vec = (order[0] == 1) ? n : k;
int mmaStride = (order[0] == 1) ? k : n;
int maxPhase = mmaStride / perPhase;
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
@@ -189,6 +206,16 @@ compared to 1*64 when the hasLeadingOffset is false.
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
}]>,
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy,
"bool":$needTrans), [{
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans);
}]>,
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,

View File

@@ -60,9 +60,14 @@ public:
// used here. For tests where numCTAs = 1, this is not a problem since all
// CTALayouts are the same.
auto newXOrder = triton::gpu::getOrder(argEncoding);
// set needTrans to true here. newXEncoding is computed based on argEncoding
// which is before the transpose. without needTrans we will compute vec and
// maxPhase based on incorrect m, n and k size of mma. the type inference of
// TransOp simply swap the order but doesn't fix the vec and maxPhase for
// the YType, hence it would causing incorrect swizzling code.
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XEncoding.getCTALayout(), XType.getElementType());
XEncoding.getCTALayout(), XType.getElementType(), true);
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)

View File

@@ -652,10 +652,12 @@ void LoopPipeliner::createBufferTypes() {
.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
// MMAv1 and MMAv2
bool needTrans = dyn_cast_or_null<tt::TransOp>(
cvt.getDefiningOp()->getOperand(0).getDefiningOp());
unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth();
sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth);
ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans);
} else {
// MMAv3
sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(),