enhance(compiler/midlfhe): Handle linalg.generic in parametrization pass

This commit is contained in:
Quentin Bourgerie
2021-08-25 14:08:42 +02:00
parent 15fd194075
commit a654fb2d0e
2 changed files with 5 additions and 2 deletions

View File

@@ -2,6 +2,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
@@ -192,6 +193,8 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
mlir::OwningRewritePatternList patterns(&getContext());
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
fheContext.parameter);
patterns.add<LinalgGenericTypeConverterPattern<
MidLFHEGlobalParametrizationTypeConverter>>(&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);

View File

@@ -30,7 +30,7 @@ public:
MidLFHEToLowLFHETypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](GLWECipherTextType type) {
return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type);
return mlir::zamalang::convertTypeToLWE(type.getContext(), type);
});
addConversion([&](mlir::RankedTensorType type) {
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
@@ -39,7 +39,7 @@ public:
}
mlir::Type r = mlir::RankedTensorType::get(
type.getShape(),
mlir::zamalang::convertTypeGLWEToLWE(glwe.getContext(), glwe));
mlir::zamalang::convertTypeToLWE(glwe.getContext(), glwe));
return r;
});
}