mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Tweak warpsPerCTA based on the shape of MMA output (#2581)
In the current implementation, warpsPerCTA is always set to [numWarps, 1] in the scenario where two tt.dot ops are fused. However, this is not optimal when the tt.dot output does not have enough parallelism along the row dimension but does have it along the column dimension.
This commit is contained in:
@@ -50,8 +50,13 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
};
|
||||
auto slices = multiRootGetSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp)) {
|
||||
if (shape[0] >= shape[1]) {
|
||||
return {(unsigned)numWarps, 1};
|
||||
} else {
|
||||
return {1, (unsigned)numWarps};
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
|
||||
Reference in New Issue
Block a user