[BACKEND] Rewrite Membar to fit the CF dialect (#1213)

This commit is contained in:
Keren Zhou
2023-02-19 17:54:33 -05:00
committed by GitHub
parent 6b44d31ae4
commit 123c687ed9
9 changed files with 470 additions and 223 deletions

View File

@@ -1,6 +1,8 @@
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"
@@ -24,21 +26,25 @@ struct TestMembarPass
// Convert to std::string can remove quotes from op_name
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
// Lower the module to the cf dialect
auto *context = operation->getContext();
RewritePatternSet scfPatterns(context);
mlir::populateSCFToControlFlowConversionPatterns(scfPatterns);
mlir::ConversionTarget scfTarget(*context);
scfTarget.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
scf::ExecuteRegionOp>();
scfTarget.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(applyPartialConversion(operation, scfTarget,
std::move(scfPatterns))))
return signalPassFailure();
// Print all ops after membar pass
Allocation allocation(operation);
MembarAnalysis membarPass(&allocation);
membarPass.run();
size_t operationId = 0;
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<gpu::BarrierOp>(op)) {
os << "Membar " << operationId << "\n";
}
if (op->getNumRegions() == 0) {
// Don't count parent Operation to simplify the test.
operationId++;
}
return;
});
os << *operation << "\n";
}
};