mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Reverts openai/triton#2485
This commit is contained in:
@@ -44,9 +44,17 @@ static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
getWarpsPerTileV2(const ArrayRef<int64_t> shape,
|
||||
const SmallVector<int64_t, 2> &shapePerWarp, int numWarps) {
|
||||
warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
// TODO (@daadaada): double-check.
|
||||
// original logic in
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
||||
@@ -67,27 +75,6 @@ getWarpsPerTileV2(const ArrayRef<int64_t> shape,
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp)) {
|
||||
if (shape[0] < shapePerWarp[0] * numWarps &&
|
||||
shape[1] > shapePerWarp[1] * numWarps) {
|
||||
return getWarpsPerTileV2(shape, shapePerWarp, numWarps);
|
||||
} else {
|
||||
return {(unsigned)numWarps, 1};
|
||||
}
|
||||
}
|
||||
|
||||
return getWarpsPerTileV2(shape, shapePerWarp, numWarps);
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
const SmallVector<unsigned, 3> &instrShape) {
|
||||
|
||||
Reference in New Issue
Block a user