[OPTIMIZER] Tweaked layout removal conversion heuristics (#1501)

Loads are now considered cheap to rematerialize when there are more
threads than elements in the tensor.
This commit is contained in:
Philippe Tillet
2023-04-10 15:19:08 -07:00
committed by GitHub
parent 2c06f875e4
commit 640f3c3921
2 changed files with 36 additions and 8 deletions

View File

@@ -89,10 +89,21 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
}
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 1: A size 1 tensor is not expensive since all threads will load the
// Case 1a: A size 1 tensor is not expensive since all threads will load the
// same
if (isSingleValue(op->getOperand(0)))
return false;
// Case 1b: Tensor of pointers has more threads than elements
// we can presume a high hit-rate that makes it cheap to load
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
IntegerAttr numWarps =
op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
"triton_gpu.num-warps");
if (numWarps) {
int sizePerThread = triton::gpu::getElemsPerThread(ptrType);
if (ptrType.getNumElements() < numWarps.getInt() * 32)
return false;
}
// auto ptr = op->getOperand(0);
//// Case 2: We assume that `evict_last` loads/stores have high hit rate
// if (auto load = dyn_cast<triton::LoadOp>(op))
@@ -103,15 +114,17 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// return false;
// if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
// auto encoding = tensorTy.getEncoding();
// // Case 3: Different type conversion is expensive (e.g., mma <-> block)
// if (encoding.getTypeID() != targetEncoding.getTypeID())
// // Case 3: Different type conversion is expensive (e.g., mma <->
// block) if (encoding.getTypeID() != targetEncoding.getTypeID())
// return true;
// auto sizePerThread = triton::gpu::getSizePerThread(encoding);
// auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
// auto order = triton::gpu::getOrder(encoding);
// auto targetOrder = triton::gpu::getOrder(targetEncoding);
// // Case 4: The targeEncoding may expose more vectorization opportunities
// return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
// auto targetSizePerThread =
// triton::gpu::getSizePerThread(targetEncoding); auto order =
// triton::gpu::getOrder(encoding); auto targetOrder =
// triton::gpu::getOrder(targetEncoding);
// // Case 4: The targeEncoding may expose more vectorization
// opportunities return sizePerThread[order[0]] >=
// targetSizePerThread[targetOrder[0]];
// }
return true;
}

View File

@@ -67,6 +67,21 @@ func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
return
}
// Rematerialization of a "fast" load: with 4 warps (presumably 32 threads
// each, i.e. 128 threads total — confirm against the pass's `numWarps * 32`
// guard) a 16-element tensor has fewer elements than threads, so the load
// counts as cheap and both layout-conversion ops should be folded away.
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
%2 = tt.addptr %0, %1 : tensor<16x!tt.ptr<i32>, #layout1>, tensor<16xi32, #layout1>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<16xi32, #layout1>) -> tensor<16xi32, #layout0>
%5 = triton_gpu.convert_layout %2 : (tensor<16x!tt.ptr<i32>, #layout1>) -> tensor<16x!tt.ptr<i32>, #layout0>
tt.store %5, %4 : tensor<16xi32, #layout0>
return
}
}
// CHECK-LABEL: if
func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout