mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] AxisInfoVisitor for LoadOp constancy calculation (#1968)
If you call `result = load(x, mask)` where `x` and `mask` have some constancy properties, then you can infer some constancy properties for `result`.
This commit is contained in:
@@ -469,6 +469,36 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl<triton::LoadOp> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::LoadOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo
|
||||
getAxisInfo(triton::LoadOp op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
// If pointers and mask both have constancy properties, those properties
|
||||
// will also extend to output.
|
||||
AxisInfo ptrInfo = operands[0]->getValue();
|
||||
std::optional<AxisInfo> maskInfo;
|
||||
if (operands.size() > 1) {
|
||||
maskInfo = operands[1]->getValue();
|
||||
}
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
|
||||
for (int d = 0; d < ptrInfo.getRank(); ++d) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(
|
||||
gcd(ptrInfo.getConstancy(d),
|
||||
maskInfo.has_value() ? maskInfo->getConstancy(d) : 0));
|
||||
}
|
||||
|
||||
return AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
};
|
||||
|
||||
class ExpandDimsOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<triton::ExpandDimsOp> {
|
||||
public:
|
||||
@@ -871,6 +901,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
||||
MaxMinOpAxisInfoVisitor<arith::MaxUIOp>,
|
||||
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
|
||||
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
|
||||
visitors.append<LoadOpAxisInfoVisitor>();
|
||||
}
|
||||
|
||||
void AxisInfoAnalysis::visitOperation(
|
||||
|
||||
@@ -402,6 +402,35 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @load_constancy
|
||||
tt.func @load_constancy(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 1 : i32}) {
|
||||
// CHECK: divisibility = [16]
|
||||
%sixteen = arith.constant dense<16> : tensor<1024xi32>
|
||||
// CHECK-NEXT: divisibility = [8]
|
||||
%eight = arith.constant dense<8> : tensor<1024xi32>
|
||||
// CHECK-NEXT: contiguity = [1024], divisibility = [1073741824], constancy = [1]
|
||||
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
|
||||
// CHECK-NEXT: constancy = [16]
|
||||
%2 = arith.divsi %1, %sixteen : tensor<1024xi32>
|
||||
// CHECK-NEXT: constancy = [1024]
|
||||
%3 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: constancy = [1024]
|
||||
%4 = tt.splat %arg1 : (i32) -> tensor<1024xi32>
|
||||
// CHECK-NEXT: constancy = [8]
|
||||
%5 = arith.divsi %1, %eight : tensor<1024xi32>
|
||||
// CHECK-NEXT: constancy = [8]
|
||||
%6 = arith.cmpi slt, %5, %4 : tensor<1024xi32>
|
||||
// CHECK-NEXT: constancy = [16]
|
||||
%7 = tt.addptr %3, %2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
|
||||
// CHECK-NEXT: constancy = [16]
|
||||
%8 = tt.load %7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
|
||||
// CHECK-NEXT: constancy = [8]
|
||||
%9 = tt.load %7, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
|
||||
// CHECK-LABEL: @store_constant_align
|
||||
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
|
||||
Reference in New Issue
Block a user