mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
Reverts openai/triton#2485
This commit is contained in:
@@ -44,9 +44,17 @@ static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
getWarpsPerTileV2(const ArrayRef<int64_t> shape,
|
||||
const SmallVector<int64_t, 2> &shapePerWarp, int numWarps) {
|
||||
warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
// TODO (@daadaada): double-check.
|
||||
// original logic in
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
||||
@@ -67,27 +75,6 @@ getWarpsPerTileV2(const ArrayRef<int64_t> shape,
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp)) {
|
||||
if (shape[0] < shapePerWarp[0] * numWarps &&
|
||||
shape[1] > shapePerWarp[1] * numWarps) {
|
||||
return getWarpsPerTileV2(shape, shapePerWarp, numWarps);
|
||||
} else {
|
||||
return {(unsigned)numWarps, 1};
|
||||
}
|
||||
}
|
||||
|
||||
return getWarpsPerTileV2(shape, shapePerWarp, numWarps);
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
const SmallVector<unsigned, 3> &instrShape) {
|
||||
|
||||
@@ -147,14 +147,14 @@ flash_attention_data = {
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.230,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.146,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.098,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
|
||||
|
||||
Reference in New Issue
Block a user