[OPTIMIZER] Infer the alignment info of loops' induction variables (#1350)

Before this PR, the alignment info of loops' induction variables (IVs) was lost.
For example:
```
for n in range(0, K, BLOCK):
     x = base + n
                       ^--  Triton doesn't know n is always a multiple of BLOCK
```

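This PR fixes this by adding axis-info rules for `LLVM::ConstantOp` and `LLVM::AddOp`, so that the alignment of induction variables can be inferred.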

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Da Yan
2023-03-16 03:39:08 -04:00
committed by GitHub
parent 4b774ee4d0
commit 9d5505d043

View File

@@ -164,16 +164,16 @@ public:
}
};
class ConstantOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<arith::ConstantOp> {
template <typename OpTy>
class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
AxisInfo
getAxisInfo(arith::ConstantOp op,
getAxisInfo(OpTy op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
auto intAttr = op.getValue().template dyn_cast<IntegerAttr>();
auto boolAttr = op.getValue().template dyn_cast<BoolAttr>();
if (intAttr || boolAttr) {
int64_t value{};
if (intAttr)
@@ -186,10 +186,10 @@ public:
/*knownConstantValue=*/{value});
}
// TODO: generalize to dense attr
auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
auto splatAttr = op.getValue().template dyn_cast<SplatElementsAttr>();
if (splatAttr && splatAttr.getElementType().isIntOrIndex()) {
int64_t value = splatAttr.getSplatValue<APInt>().getZExtValue();
TensorType ty = splatAttr.getType().cast<TensorType>();
int64_t value = splatAttr.template getSplatValue<APInt>().getZExtValue();
TensorType ty = splatAttr.getType().template cast<TensorType>();
return AxisInfo(
/*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
/*divisibility=*/
@@ -233,7 +233,8 @@ private:
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value()) {
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
std::is_same_v<OpTy, triton::AddPtrOp>) {
std::is_same_v<OpTy, triton::AddPtrOp> ||
std::is_same_v<OpTy, LLVM::AddOp>) {
return {lhs.getConstantValue().value() +
rhs.getConstantValue().value()};
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
@@ -812,11 +813,15 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
// TODO: Remove the rules for LLVM::ConstantOp and LLVM::AddOp
// once scf.for supports integer induction variables
visitors.append<MakeRangeOpAxisInfoVisitor>();
visitors.append<ConstantOpAxisInfoVisitor>();
visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
AddSubOpAxisInfoVisitor<arith::AddIOp>,
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
AddSubOpAxisInfoVisitor<arith::SubIOp>,
AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
visitors.append<MulIOpAxisInfoVisitor>();
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
DivOpAxisInfoVisitor<arith::DivUIOp>>();