Files
ROCm/test/lib/Analysis/TestAxisInfo.cpp
Keren Zhou 82befe32ad [BACKEND] Improve torch inductor performance (#1108)
- Rewrite the AxisInfo analysis to handle each op case by case.
- Add bit shift, min max, div/rem, and select ops to AxisInfo.
- Rematerialize across load/store ops in the following two cases:
- A size-1 tensor is considered inexpensive, since all threads will
load the same value
- The targetEncoding may expose more vectorization opportunities (more
elements per thread on the first dim)

**_res2next_** benchmark GPU Kernel time comparison on A100.
- Average kernel sum. Triton 16838630ns vs Triton-MLIR 17105166ns.
**1.016x slowdown**.
- Total kernel sum. Triton 6511735460ns vs Triton-MLIR 6512370620ns.
2023-02-01 18:21:15 -08:00

78 lines
2.1 KiB
C++

#include "mlir/Pass/Pass.h"
#include "triton/Analysis/AxisInfo.h"
using namespace mlir;
namespace {
/// Test pass that runs AxisInfoAnalysis over a function and prints, for every
/// SSA result produced inside it, the inferred contiguity, divisibility,
/// constancy and constant value to stderr (one line per result).
/// Registered under the pass argument "test-print-alignment".
struct TestAxisInfoPass
    : public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
  // LLVM15+
  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);

  /// Writes `vals` to `os` as "<name>: [v0, v1, ...]" with no trailing
  /// newline. Does not touch any pass state, hence static.
  static void print(const std::string &name, raw_ostream &os,
                    ArrayRef<int64_t> vals) {
    os << name << ": [";
    for (size_t d = 0; d < vals.size(); d++) {
      if (d != 0)
        os << ", ";
      os << vals[d];
    }
    os << "]";
  }

  StringRef getArgument() const final { return "test-print-alignment"; }
  StringRef getDescription() const final {
    return "print the result of the alignment analysis pass";
  }

  void runOnOperation() override {
    Operation *operation = getOperation();
    auto &os = llvm::errs();
    // Header line: the symbol name of the function being analyzed.
    auto opName = SymbolTable::getSymbolName(operation).getValue().str();
    os << opName << "\n";

    AxisInfoAnalysis analysis(&getContext());
    analysis.run(operation);

    operation->walk([&](Operation *op) {
      // Ops without results contribute no output (the loop body never runs).
      for (Value result : op->getResults()) {
        LatticeElement<AxisInfo> *latticeElement =
            analysis.lookupLatticeElement(result);
        if (!latticeElement) {
          os << "None\n";
          // Skip only this result; the remaining results of a multi-result
          // op must still be reported (previously `return` dropped them).
          continue;
        }
        AxisInfo &info = latticeElement->getValue();
        print("Contiguity", os, info.getContiguity());
        os << " ; ";
        print("Divisibility", os, info.getDivisibility());
        os << " ; ";
        print("Constancy", os, info.getConstancy());
        os << " ; ";
        auto constantValue = info.getConstantValue();
        os << "ConstantValue: [";
        if (constantValue.has_value())
          os << constantValue.value();
        else
          os << "None";
        os << "] ( ";
        // Echo the SSA value itself so each line can be matched to its op.
        result.print(os);
        os << " ) ";
        os << "\n";
      }
    });
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestAlignmentPass() { PassRegistration<TestAxisInfoPass>(); }
} // namespace test
} // namespace mlir