[OPTIMIZER] AxisInfoVisitor for LoadOp constancy calculation (#1968)

If you call `result = load(x, mask)` where `x` and `mask` have some
constancy properties, then you can infer some constancy properties for
`result`.
This commit is contained in:
David Berard
2023-07-19 17:40:46 -07:00
committed by GitHub
parent 68124676c9
commit 9c422e260b
2 changed files with 60 additions and 0 deletions

View File

@@ -469,6 +469,36 @@ public:
}
};
class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl<triton::LoadOp> {
public:
using AxisInfoVisitorImpl<triton::LoadOp>::AxisInfoVisitorImpl;
AxisInfo
getAxisInfo(triton::LoadOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
// If pointers and mask both have constancy properties, those properties
// will also extend to output.
AxisInfo ptrInfo = operands[0]->getValue();
std::optional<AxisInfo> maskInfo;
if (operands.size() > 1) {
maskInfo = operands[1]->getValue();
}
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
for (int d = 0; d < ptrInfo.getRank(); ++d) {
contiguity.push_back(1);
divisibility.push_back(1);
constancy.push_back(
gcd(ptrInfo.getConstancy(d),
maskInfo.has_value() ? maskInfo->getConstancy(d) : 0));
}
return AxisInfo(contiguity, divisibility, constancy);
}
};
class ExpandDimsOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<triton::ExpandDimsOp> {
public:
@@ -871,6 +901,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
MaxMinOpAxisInfoVisitor<arith::MaxUIOp>,
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
visitors.append<LoadOpAxisInfoVisitor>();
}
void AxisInfoAnalysis::visitOperation(

View File

@@ -402,6 +402,35 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
// -----
// CHECK-LABEL: @load_constancy
tt.func @load_constancy(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 1 : i32}) {
// CHECK: divisibility = [16]
%sixteen = arith.constant dense<16> : tensor<1024xi32>
// CHECK-NEXT: divisibility = [8]
%eight = arith.constant dense<8> : tensor<1024xi32>
// CHECK-NEXT: contiguity = [1024], divisibility = [1073741824], constancy = [1]
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK-NEXT: constancy = [16]
%2 = arith.divsi %1, %sixteen : tensor<1024xi32>
// CHECK-NEXT: constancy = [1024]
%3 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: constancy = [1024]
%4 = tt.splat %arg1 : (i32) -> tensor<1024xi32>
// CHECK-NEXT: constancy = [8]
%5 = arith.divsi %1, %eight : tensor<1024xi32>
// CHECK-NEXT: constancy = [8]
%6 = arith.cmpi slt, %5, %4 : tensor<1024xi32>
// CHECK-NEXT: constancy = [16]
%7 = tt.addptr %3, %2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK-NEXT: constancy = [16]
%8 = tt.load %7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
// CHECK-NEXT: constancy = [8]
%9 = tt.load %7, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
tt.return
}
// -----
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
// CHECK-LABEL: @store_constant_align
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {