mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
refactor(compiler): generalize LinalgGenericPattern
also fixes an issue regarding populateWithGenerated, which can be duplicated across different pattern files. So I redefined a different function that is more unique to the pass that should be ran, and hide the populateWithGenerated from the global namespace
This commit is contained in:
@@ -62,6 +62,12 @@ createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter rewriter,
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h.inc"
|
||||
}
|
||||
|
||||
void populateWithGeneratedHLFHEToMidLFHE(mlir::RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
/// LinalgGenericTypeConverterPattern is a rewrite pattern that convert types
|
||||
/// `linalg.generic` operation, using a specific `typeConverter`
|
||||
template <typename TypeConverter>
|
||||
struct LinalgGenericTypeConverterPattern
|
||||
: public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
|
||||
LinalgGenericTypeConverterPattern(mlir::MLIRContext *context,
|
||||
TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<mlir::linalg::GenericOp>(context, benefit),
|
||||
converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::linalg::GenericOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
// Rewrite arguments
|
||||
{
|
||||
for (auto i = 0; i < op->getNumOperands(); i++) {
|
||||
auto operand = op->getOperand(i);
|
||||
mlir::Type type = converter.convertType(operand.getType());
|
||||
if (type != mlir::Type()) {
|
||||
operand.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Rewrite results
|
||||
{
|
||||
for (auto i = 0; i < op->getNumResults(); i++) {
|
||||
auto result = op->getResult(i);
|
||||
mlir::Type type = converter.convertType(result.getType());
|
||||
if (type != mlir::Type()) {
|
||||
result.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Rewrite block arguments
|
||||
mlir::Region ®ion = op->getRegion(0);
|
||||
mlir::Block *entry = ®ion.front();
|
||||
for (auto arg : entry->getArguments()) {
|
||||
mlir::Type type = converter.convertType(arg.getType());
|
||||
if (type != mlir::Type()) {
|
||||
arg.setType(type);
|
||||
}
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
TypeConverter &converter;
|
||||
};
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
@@ -64,55 +65,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// LinalgGenericPattern is a rewrite pattern that convert types from HLFHE to
|
||||
/// MidLFHE of `linalg.generic` operation
|
||||
struct LinalgGenericPattern
|
||||
: public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
|
||||
LinalgGenericPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<mlir::linalg::GenericOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::linalg::GenericOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
HLFHEToMidLFHETypeConverter converter;
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
// Rewrite arguments
|
||||
{
|
||||
for (auto i = 0; i < op->getNumOperands(); i++) {
|
||||
auto operand = op->getOperand(i);
|
||||
mlir::Type type = converter.convertType(operand.getType());
|
||||
if (type != mlir::Type()) {
|
||||
operand.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Rewrite results
|
||||
{
|
||||
for (auto i = 0; i < op->getNumResults(); i++) {
|
||||
auto result = op->getResult(i);
|
||||
mlir::Type type = converter.convertType(result.getType());
|
||||
if (type != mlir::Type()) {
|
||||
result.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Rewrite block arguments
|
||||
mlir::Region ®ion = op->getRegion(0);
|
||||
mlir::Block *entry = ®ion.front();
|
||||
for (auto arg : entry->getArguments()) {
|
||||
mlir::Type type = converter.convertType(arg.getType());
|
||||
if (type != mlir::Type()) {
|
||||
arg.setType(type);
|
||||
}
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
@@ -149,8 +101,9 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
// `MidLFHE`
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
populateWithGenerated(patterns);
|
||||
patterns.add<LinalgGenericPattern>(&getContext());
|
||||
populateWithGeneratedHLFHEToMidLFHE(patterns);
|
||||
patterns.add<LinalgGenericTypeConverterPattern<HLFHEToMidLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Apply conversion
|
||||
|
||||
Reference in New Issue
Block a user