mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Rewrite Membar to fit the CF dialect (#1213)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Analysis/Membar.h"
|
||||
|
||||
@@ -24,21 +26,25 @@ struct TestMembarPass
|
||||
// Convert to std::string can remove quotes from op_name
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
|
||||
// Lower the module to the cf dialect
|
||||
auto *context = operation->getContext();
|
||||
RewritePatternSet scfPatterns(context);
|
||||
mlir::populateSCFToControlFlowConversionPatterns(scfPatterns);
|
||||
mlir::ConversionTarget scfTarget(*context);
|
||||
scfTarget.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
|
||||
scf::ExecuteRegionOp>();
|
||||
scfTarget.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
if (failed(applyPartialConversion(operation, scfTarget,
|
||||
std::move(scfPatterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Print all ops after membar pass
|
||||
Allocation allocation(operation);
|
||||
MembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
size_t operationId = 0;
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
if (isa<gpu::BarrierOp>(op)) {
|
||||
os << "Membar " << operationId << "\n";
|
||||
}
|
||||
if (op->getNumRegions() == 0) {
|
||||
// Don't count parent Operation to simplify the test.
|
||||
operationId++;
|
||||
}
|
||||
return;
|
||||
});
|
||||
os << *operation << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user