diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index a3394ec07..e54dbcaab 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -2,6 +2,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" @@ -27,16 +28,30 @@ class MidLFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter { public: MidLFHEGlobalParametrizationTypeConverter( mlir::zamalang::V0FHEContext &fheContext) { + auto convertGLWECiphertextType = + [](GLWECipherTextType type, mlir::zamalang::V0FHEContext &fheContext) { + auto glweSize = fheContext.parameter.getNBigGlweSize(); + auto p = fheContext.constraint.p; + if (type.getDimension() == glweSize && type.getP() == p) { + return type; + } + return GLWECipherTextType::get( + type.getContext(), glweSize, + 1 /*for the v0, is always lwe ciphertext*/, + 64 /*for the v0 we handle only q=64*/, p); + }; addConversion([](mlir::Type type) { return type; }); addConversion([&](GLWECipherTextType type) { - auto glweSize = fheContext.parameter.getNBigGlweSize(); - auto p = fheContext.constraint.p; - if (type.getDimension() == glweSize && type.getP() == p) { - return type; + return convertGLWECiphertextType(type, fheContext); + }); + addConversion([&](mlir::RankedTensorType type) { + auto glwe = type.getElementType().dyn_cast_or_null(); + if (glwe == nullptr) { + return (mlir::Type)(type); } - return GLWECipherTextType::get(type.getContext(), glweSize, - 1 /*for the v0, is always lwe ciphertext*/, - 64 /*for the v0 we handle only q=64*/, p); + mlir::Type r = mlir::RankedTensorType::get( + type.getShape(), convertGLWECiphertextType(glwe, fheContext)); + return r; }); } }; @@ -177,6 +192,8 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() { mlir::OwningRewritePatternList patterns(&getContext()); populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter, fheContext.parameter); + mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, + converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); // Apply conversion