From a654fb2d0eb08f5725d9fb20722ff343b314311e Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 25 Aug 2021 14:08:42 +0200 Subject: [PATCH] enhance(compiler/midlfhe): Handle linalg.generic in parametrization pass --- .../MidLFHEGlobalParametrization.cpp | 3 +++ compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index e54dbcaab..28ae9d137 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -2,6 +2,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h" #include "zamalang/Conversion/Utils/TensorOpTypeConversion.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" @@ -192,6 +193,8 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() { mlir::OwningRewritePatternList patterns(&getContext()); populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter, fheContext.parameter); + patterns.add>(&getContext(), converter); mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp index 9e63d906f..a4148b1c5 100644 --- a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -30,7 +30,7 @@ public: MidLFHEToLowLFHETypeConverter() { addConversion([](mlir::Type type) { return type; }); addConversion([&](GLWECipherTextType type) { - return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type); + return mlir::zamalang::convertTypeToLWE(type.getContext(), type); }); addConversion([&](mlir::RankedTensorType type) { auto glwe = type.getElementType().dyn_cast_or_null(); @@ -39,7 +39,7 @@ public: } mlir::Type r = mlir::RankedTensorType::get( type.getShape(), - mlir::zamalang::convertTypeGLWEToLWE(glwe.getContext(), glwe)); + mlir::zamalang::convertTypeToLWE(glwe.getContext(), glwe)); return r; }); }