mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Tweaked layout removal conversion heuristics (#1501)
Loads are now consider cheap to rematerialize when there are more threads than elements in the tensor
This commit is contained in:
@@ -89,10 +89,21 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
}
|
||||
|
||||
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
// Case 1: A size 1 tensor is not expensive since all threads will load the
|
||||
// Case 1a: A size 1 tensor is not expensive since all threads will load the
|
||||
// same
|
||||
if (isSingleValue(op->getOperand(0)))
|
||||
return false;
|
||||
// Case 1b: Tensor of pointers has more threads than elements
|
||||
// we can presume a high hit-rate that makes it cheap to load
|
||||
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
|
||||
IntegerAttr numWarps =
|
||||
op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
|
||||
"triton_gpu.num-warps");
|
||||
if (numWarps) {
|
||||
int sizePerThread = triton::gpu::getElemsPerThread(ptrType);
|
||||
if (ptrType.getNumElements() < numWarps.getInt() * 32)
|
||||
return false;
|
||||
}
|
||||
// auto ptr = op->getOperand(0);
|
||||
//// Case 2: We assume that `evict_last` loads/stores have high hit rate
|
||||
// if (auto load = dyn_cast<triton::LoadOp>(op))
|
||||
@@ -103,15 +114,17 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
// return false;
|
||||
// if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
|
||||
// auto encoding = tensorTy.getEncoding();
|
||||
// // Case 3: Different type conversion is expensive (e.g., mma <-> block)
|
||||
// if (encoding.getTypeID() != targetEncoding.getTypeID())
|
||||
// // Case 3: Different type conversion is expensive (e.g., mma <->
|
||||
// block) if (encoding.getTypeID() != targetEncoding.getTypeID())
|
||||
// return true;
|
||||
// auto sizePerThread = triton::gpu::getSizePerThread(encoding);
|
||||
// auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
|
||||
// auto order = triton::gpu::getOrder(encoding);
|
||||
// auto targetOrder = triton::gpu::getOrder(targetEncoding);
|
||||
// // Case 4: The targeEncoding may expose more vectorization opportunities
|
||||
// return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
|
||||
// auto targetSizePerThread =
|
||||
// triton::gpu::getSizePerThread(targetEncoding); auto order =
|
||||
// triton::gpu::getOrder(encoding); auto targetOrder =
|
||||
// triton::gpu::getOrder(targetEncoding);
|
||||
// // Case 4: The targeEncoding may expose more vectorization
|
||||
// opportunities return sizePerThread[order[0]] >=
|
||||
// targetSizePerThread[targetOrder[0]];
|
||||
// }
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -67,6 +67,21 @@ func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
|
||||
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
|
||||
%2 = tt.addptr %0, %1 : tensor<16x!tt.ptr<i32>, #layout1>, tensor<16xi32, #layout1>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<16xi32, #layout1>) -> tensor<16xi32, #layout0>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<16x!tt.ptr<i32>, #layout1>) -> tensor<16x!tt.ptr<i32>, #layout0>
|
||||
tt.store %5, %4 : tensor<16xi32, #layout0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if
|
||||
func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
|
||||
Reference in New Issue
Block a user