mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
- Rewrite the AxisInfo analysis to handle each op case by case. - Add bit shift, min/max, div/rem, and select ops to AxisInfo. - Rematerialize across load/store ops in the following two cases: - A size 1 tensor is considered not expensive since all threads will load the same value - the targetEncoding may expose more vectorization opportunities (more elements per thread on the first dim) **_res2next_** benchmark GPU Kernel time comparison on A100. - Average kernel sum. Triton 16838630ns vs Triton-MLIR 17105166ns. **1.016x slowdown**. - Total kernel sum. Triton 6511735460ns vs Triton-MLIR 6512370620ns.
78 lines
2.1 KiB
C++
78 lines
2.1 KiB
C++
#include "mlir/Pass/Pass.h"
|
|
#include "triton/Analysis/AxisInfo.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
struct TestAxisInfoPass
|
|
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
|
|
|
|
// LLVM15+
|
|
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
|
|
|
void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
|
|
os << name << ": [";
|
|
for (size_t d = 0; d < vals.size(); d++) {
|
|
if (d != 0)
|
|
os << ", ";
|
|
os << vals[d];
|
|
}
|
|
os << "]";
|
|
}
|
|
|
|
StringRef getArgument() const final { return "test-print-alignment"; }
|
|
StringRef getDescription() const final {
|
|
return "print the result of the alignment analysis pass";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
Operation *operation = getOperation();
|
|
auto &os = llvm::errs();
|
|
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
|
os << opName << "\n";
|
|
AxisInfoAnalysis analysis(&getContext());
|
|
analysis.run(operation);
|
|
operation->walk([&](Operation *op) {
|
|
if (op->getNumResults() < 1)
|
|
return;
|
|
for (Value result : op->getResults()) {
|
|
// std::ostringstream oss;
|
|
// result.print(oss);
|
|
// os << " => ";
|
|
LatticeElement<AxisInfo> *latticeElement =
|
|
analysis.lookupLatticeElement(result);
|
|
if (!latticeElement) {
|
|
os << "None\n";
|
|
return;
|
|
}
|
|
AxisInfo &info = latticeElement->getValue();
|
|
print("Contiguity", os, info.getContiguity());
|
|
os << " ; ";
|
|
print("Divisibility", os, info.getDivisibility());
|
|
os << " ; ";
|
|
print("Constancy", os, info.getConstancy());
|
|
os << " ; ";
|
|
auto constantValue = info.getConstantValue();
|
|
os << "ConstantValue: [";
|
|
if (constantValue.has_value())
|
|
os << constantValue.value();
|
|
else
|
|
os << "None";
|
|
os << "] ( ";
|
|
result.print(os);
|
|
os << " ) ";
|
|
os << "\n";
|
|
}
|
|
});
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
/// Entry point for registering the test pass: constructs a temporary
/// PassRegistration<TestAxisInfoPass>, which is the only work done here.
void registerTestAlignmentPass() {
  PassRegistration<TestAxisInfoPass>();
}
|
|
} // namespace test
|
|
} // namespace mlir
|