mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Infer the alignment info of loops' induction variables (#1350)
Before this PR, loops' induction variables' (IV) alignment info is lost.
For example:
```
for n in range(0, K, BLOCK):
x = base + n
^-- Triton doesn't know n is always a multiple of BLOCK
```
This PR fixes this.
---------
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -164,16 +164,16 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ConstantOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<arith::ConstantOp> {
|
||||
template <typename OpTy>
|
||||
class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo
|
||||
getAxisInfo(arith::ConstantOp op,
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
|
||||
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
|
||||
auto intAttr = op.getValue().template dyn_cast<IntegerAttr>();
|
||||
auto boolAttr = op.getValue().template dyn_cast<BoolAttr>();
|
||||
if (intAttr || boolAttr) {
|
||||
int64_t value{};
|
||||
if (intAttr)
|
||||
@@ -186,10 +186,10 @@ public:
|
||||
/*knownConstantValue=*/{value});
|
||||
}
|
||||
// TODO: generalize to dense attr
|
||||
auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
|
||||
auto splatAttr = op.getValue().template dyn_cast<SplatElementsAttr>();
|
||||
if (splatAttr && splatAttr.getElementType().isIntOrIndex()) {
|
||||
int64_t value = splatAttr.getSplatValue<APInt>().getZExtValue();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
int64_t value = splatAttr.template getSplatValue<APInt>().getZExtValue();
|
||||
TensorType ty = splatAttr.getType().template cast<TensorType>();
|
||||
return AxisInfo(
|
||||
/*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
|
||||
/*divisibility=*/
|
||||
@@ -233,7 +233,8 @@ private:
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value()) {
|
||||
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
|
||||
std::is_same_v<OpTy, triton::AddPtrOp>) {
|
||||
std::is_same_v<OpTy, triton::AddPtrOp> ||
|
||||
std::is_same_v<OpTy, LLVM::AddOp>) {
|
||||
return {lhs.getConstantValue().value() +
|
||||
rhs.getConstantValue().value()};
|
||||
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
|
||||
@@ -812,11 +813,15 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
||||
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
|
||||
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
|
||||
CastOpAxisInfoVisitor<triton::BitcastOp>>();
|
||||
// TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
|
||||
// when scf.for supports integers induction variable
|
||||
visitors.append<MakeRangeOpAxisInfoVisitor>();
|
||||
visitors.append<ConstantOpAxisInfoVisitor>();
|
||||
visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
|
||||
ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
|
||||
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
|
||||
AddSubOpAxisInfoVisitor<arith::AddIOp>,
|
||||
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
|
||||
AddSubOpAxisInfoVisitor<arith::SubIOp>,
|
||||
AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
|
||||
visitors.append<MulIOpAxisInfoVisitor>();
|
||||
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
|
||||
DivOpAxisInfoVisitor<arith::DivUIOp>>();
|
||||
|
||||
Reference in New Issue
Block a user