mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Fix Shared layout in OptimizeDotOperands pass to generate correct swizzling code (#2180)
Fixes bug #1937. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -113,6 +113,16 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"unsigned":$typeWidthInBit), [{
|
||||
bool needTrans = false; // default value
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"unsigned":$typeWidthInBit,
|
||||
"bool":$needTrans), [{
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
@@ -152,16 +162,23 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
|
||||
int m = (needTrans) ? matShape[2] : matShape[0];
|
||||
int k = (needTrans) ? matShape[0] : matShape[2];
|
||||
int vec = (order[0] == 1) ? k : m;
|
||||
int mmaStride = (order[0] == 1) ? m : k;
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
if (opIdx == 1) {
|
||||
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
|
||||
// We compute vec and maxPhase from the m, n and k sizes of the mma
|
||||
// instruction. When a matmul operand is transposed, we must take
|
||||
// that into account to get the correct m, n and k.
|
||||
int n = needTrans ? matShape[2] : matShape[1];
|
||||
int k = needTrans ? matShape[1] : matShape[2];
|
||||
int vec = (order[0] == 1) ? n : k;
|
||||
int mmaStride = (order[0] == 1) ? k : n;
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
@@ -189,6 +206,16 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy,
|
||||
"bool":$needTrans), [{
|
||||
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
|
||||
@@ -60,9 +60,14 @@ public:
|
||||
// used here. For tests where numCTAs = 1, this is not a problem since all
|
||||
// CTALayouts are the same.
|
||||
auto newXOrder = triton::gpu::getOrder(argEncoding);
|
||||
// Set needTrans to true here. newXEncoding is computed based on argEncoding,
|
||||
// which is from before the transpose. Without needTrans we would compute vec
|
||||
// and maxPhase from incorrect m, n and k sizes of the mma. The type inference
|
||||
// of TransOp simply swaps the order but doesn't fix the vec and maxPhase for
|
||||
// the YType, hence it would cause incorrect swizzling code.
|
||||
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
getContext(), ZEncoding, XType.getShape(), newXOrder,
|
||||
XEncoding.getCTALayout(), XType.getElementType());
|
||||
XEncoding.getCTALayout(), XType.getElementType(), true);
|
||||
auto newXType = RankedTensorType::get(XType.getShape(),
|
||||
XType.getElementType(), newXEncoding);
|
||||
if (XEncoding == newXEncoding)
|
||||
|
||||
@@ -652,10 +652,12 @@ void LoopPipeliner::createBufferTypes() {
|
||||
.getEncoding()
|
||||
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
|
||||
// MMAv1 and MMAv2
|
||||
bool needTrans = dyn_cast_or_null<tt::TransOp>(
|
||||
cvt.getDefiningOp()->getOperand(0).getDefiningOp());
|
||||
unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth();
|
||||
sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth);
|
||||
ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans);
|
||||
} else {
|
||||
// MMAv3
|
||||
sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(),
|
||||
|
||||
Reference in New Issue
Block a user