[OPTIMIZER] Tweak warpsPerCTA based on the shape of MMA output (#2581)

In the current implementation, warpsPerCTA is always set to [numWarps, 1]
when two tt.dot ops are fused. This is not optimal when the tt.dot output
has too little parallelism along the row dimension but enough along the
column dimension.
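
To make the rule concrete, here is a minimal standalone sketch of the selection logic, assuming `shape` holds the {rows, columns} of the MMA output. The name `warpsForChainedDot` and the use of `std::array` are illustrative assumptions, not part of the Triton code base; the real change lives inside `warpsPerTileV2`.

```cpp
#include <array>
#include <cstdint>

// Hypothetical helper (not a Triton API): choose the warp layout for a
// tt.dot that is chained with another tt.dot.
// shape is assumed to be {rows, cols} of the MMA output.
std::array<unsigned, 2> warpsForChainedDot(const std::array<int64_t, 2> &shape,
                                           int numWarps) {
  // Rows offer at least as much parallelism as columns: stack all warps
  // along the row dimension, which is what the code did unconditionally
  // before this commit.
  if (shape[0] >= shape[1])
    return {static_cast<unsigned>(numWarps), 1u};
  // Otherwise spread the warps along the column dimension.
  return {1u, static_cast<unsigned>(numWarps)};
}
```

Under these assumptions, a 16x128 output with 4 warps yields {1, 4}, while a 128x64 output keeps the previous {4, 1} layout.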
Weixing Zhang
2023-11-03 13:40:03 -07:00
committed by GitHub
parent 6ac9d51ff0
commit 34b89a1173
2 changed files with 11 additions and 6 deletions

@@ -50,8 +50,13 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   };
   auto slices = multiRootGetSlice(dotOp, {filter});
   for (Operation *op : slices)
-    if (isa<tt::DotOp>(op) && (op != dotOp))
-      return {(unsigned)numWarps, 1};
+    if (isa<tt::DotOp>(op) && (op != dotOp)) {
+      if (shape[0] >= shape[1]) {
+        return {(unsigned)numWarps, 1};
+      } else {
+        return {1, (unsigned)numWarps};
+      }
+    }
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp = {16, 8};