Revert "[OPTIMIZER] Tweak warpsPerCTA based on the shape of MMA output (#2485)" (#2497)

Reverts openai/triton#2485
This commit is contained in:
Philippe Tillet
2023-10-13 23:32:59 -07:00
committed by GitHub
parent 76858bd917
commit 8db4fac3b0
2 changed files with 13 additions and 26 deletions

View File

@@ -44,9 +44,17 @@ static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
}
SmallVector<unsigned, 2>
getWarpsPerTileV2(const ArrayRef<int64_t> shape,
const SmallVector<int64_t, 2> &shapePerWarp, int numWarps) {
warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, {filter});
for (Operation *op : slices)
if (isa<tt::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
@@ -67,27 +75,6 @@ getWarpsPerTileV2(const ArrayRef<int64_t> shape,
return ret;
}
// Picks the two-element warps-per-tile split for `dotOp`.
//
// The slice of `dotOp` (restricted to ops living in the same parent region)
// is scanned for a second tt::DotOp. When one is found and the tile is NOT
// both "short" in dim 0 and "wide" in dim 1 relative to
// shapePerWarp * numWarps, every warp is placed along dim 0, i.e.
// {numWarps, 1}. In every other case the decision is delegated to
// getWarpsPerTileV2().
//
// NOTE(review): `shape` is indexed at [0] and [1] without a rank check, so
// callers are presumably expected to pass a rank-2 shape -- confirm.
SmallVector<unsigned, 2>
warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
  // NOTE(review): {16, 8} looks like the per-warp output tile extents --
  // confirm against the MMA instruction shape.
  SmallVector<int64_t, 2> shapePerWarp = {16, 8};

  // Only ops that share dotOp's parent region take part in the slice.
  auto inSameRegion = [&dotOp](Operation *candidate) {
    return candidate->getParentRegion() == dotOp->getParentRegion();
  };
  auto slices = mlir::getSlice(dotOp, {inSameRegion});

  // Step 1: is there any *other* dot in the slice?
  bool hasOtherDot = false;
  for (Operation *sliceOp : slices) {
    if (sliceOp != dotOp && isa<tt::DotOp>(sliceOp)) {
      hasOtherDot = true;
      break;
    }
  }

  // Step 2: with another dot present, keep all warps along dim 0 unless the
  // tile is short in dim 0 and wide in dim 1.
  if (hasOtherDot) {
    const bool shortInDim0 = shape[0] < shapePerWarp[0] * numWarps;
    const bool wideInDim1 = shape[1] > shapePerWarp[1] * numWarps;
    if (!(shortInDim0 && wideInDim1))
      return {static_cast<unsigned>(numWarps), 1};
  }

  return getWarpsPerTileV2(shape, shapePerWarp, numWarps);
}
SmallVector<unsigned, 2>
warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {

View File

@@ -147,14 +147,14 @@ flash_attention_data = {
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.230,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.146,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098,
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134,
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.098,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,