mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Fix Shared layout in OptimizeDotOperands pass to generate correct swizzling code (#2180)
fix bug #1937 Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -113,6 +113,16 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"unsigned":$typeWidthInBit), [{
|
||||
bool needTrans = false; // default value
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"unsigned":$typeWidthInBit,
|
||||
"bool":$needTrans), [{
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
@@ -152,16 +162,23 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
|
||||
int m = (needTrans) ? matShape[2] : matShape[0];
|
||||
int k = (needTrans) ? matShape[0] : matShape[2];
|
||||
int vec = (order[0] == 1) ? k : m;
|
||||
int mmaStride = (order[0] == 1) ? m : k;
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
if (opIdx == 1) {
|
||||
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
|
||||
// we compute vec and maxPhase m, n and k size of the mma
|
||||
// instruction. when matmul operands is transposed, we should
|
||||
// consider that to get m, n and k.
|
||||
int n = needTrans ? matShape[2] : matShape[1];
|
||||
int k = needTrans ? matShape[1] : matShape[2];
|
||||
int vec = (order[0] == 1) ? n : k;
|
||||
int mmaStride = (order[0] == 1) ? k : n;
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return get(context, vec, perPhase, maxPhase, order, CTALayout);
|
||||
}
|
||||
@@ -189,6 +206,16 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
"Type":$eltTy,
|
||||
"bool":$needTrans), [{
|
||||
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
|
||||
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans);
|
||||
}]>,
|
||||
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"CTALayoutAttr":$CTALayout,
|
||||
|
||||
Reference in New Issue
Block a user