[OPTIMIZER] Fix Shared layout in OptimizeDotOperands pass to generate correct swizzling code (#2180)

fix bug #1937

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Bin Fan
2023-09-13 12:52:09 -07:00
committed by GitHub
parent c61d772eee
commit 38a2ecdccf
3 changed files with 40 additions and 6 deletions

View File

@@ -113,6 +113,16 @@ compared to 1*64 when the hasLeadingOffset is false.
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit), [{
bool needTrans = false; // default value
return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans);
}]>,
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit,
"bool":$needTrans), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
@@ -152,16 +162,23 @@ compared to 1*64 when the hasLeadingOffset is false.
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int m = (needTrans) ? matShape[2] : matShape[0];
int k = (needTrans) ? matShape[0] : matShape[2];
int vec = (order[0] == 1) ? k : m;
int mmaStride = (order[0] == 1) ? m : k;
int maxPhase = mmaStride / perPhase;
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
// --- handle B operand ---
if (opIdx == 1) {
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
// We compute vec and maxPhase from the m, n and k sizes of the mma
// instruction. When the matmul operand is transposed, we must take
// that into account when determining m, n and k.
int n = needTrans ? matShape[2] : matShape[1];
int k = needTrans ? matShape[1] : matShape[2];
int vec = (order[0] == 1) ? n : k;
int mmaStride = (order[0] == 1) ? k : n;
int maxPhase = mmaStride / perPhase;
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
@@ -189,6 +206,16 @@ compared to 1*64 when the hasLeadingOffset is false.
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
}]>,
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy,
"bool":$needTrans), [{
// Convenience overload: takes the element type instead of a bit width.
// The width is read from eltTy via getIntOrFloatBitWidth() and forwarded,
// together with needTrans, to the builder overload that accepts a bit
// width as its sixth argument.
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans);
}]>,
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,