diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index e54dbcaab..28ae9d137 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -2,6 +2,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h" #include "zamalang/Conversion/Utils/TensorOpTypeConversion.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" @@ -192,6 +193,8 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() { mlir::OwningRewritePatternList patterns(&getContext()); populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter, fheContext.parameter); + patterns.add>(&getContext(), converter); mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp index 9e63d906f..a4148b1c5 100644 --- a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -30,7 +30,7 @@ public: MidLFHEToLowLFHETypeConverter() { addConversion([](mlir::Type type) { return type; }); addConversion([&](GLWECipherTextType type) { - return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type); + return mlir::zamalang::convertTypeToLWE(type.getContext(), type); }); addConversion([&](mlir::RankedTensorType type) { auto glwe = type.getElementType().dyn_cast_or_null(); @@ -39,7 +39,7 @@ public: } mlir::Type r = mlir::RankedTensorType::get( type.getShape(), - mlir::zamalang::convertTypeGLWEToLWE(glwe.getContext(), glwe)); + mlir::zamalang::convertTypeToLWE(glwe.getContext(), glwe)); return r; }); }