fix(compiler/midlfhe): Manage tensor in midlfhe parametrization

This commit is contained in:
Quentin Bourgerie
2021-08-24 17:18:54 +02:00
parent af0789f128
commit 19f1a22b6a

View File

@@ -2,6 +2,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
@@ -27,16 +28,30 @@ class MidLFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter {
public:
MidLFHEGlobalParametrizationTypeConverter(
mlir::zamalang::V0FHEContext &fheContext) {
auto convertGLWECiphertextType =
[](GLWECipherTextType type, mlir::zamalang::V0FHEContext &fheContext) {
auto glweSize = fheContext.parameter.getNBigGlweSize();
auto p = fheContext.constraint.p;
if (type.getDimension() == glweSize && type.getP() == p) {
return type;
}
return GLWECipherTextType::get(
type.getContext(), glweSize,
1 /*for the v0, is always lwe ciphertext*/,
64 /*for the v0 we handle only q=64*/, p);
};
addConversion([](mlir::Type type) { return type; });
addConversion([&](GLWECipherTextType type) {
auto glweSize = fheContext.parameter.getNBigGlweSize();
auto p = fheContext.constraint.p;
if (type.getDimension() == glweSize && type.getP() == p) {
return type;
return convertGLWECiphertextType(type, fheContext);
});
addConversion([&](mlir::RankedTensorType type) {
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
if (glwe == nullptr) {
return (mlir::Type)(type);
}
return GLWECipherTextType::get(type.getContext(), glweSize,
1 /*for the v0, is always lwe ciphertext*/,
64 /*for the v0 we handle only q=64*/, p);
mlir::Type r = mlir::RankedTensorType::get(
type.getShape(), convertGLWECiphertextType(glwe, fheContext));
return r;
});
}
};
@@ -177,6 +192,8 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
mlir::OwningRewritePatternList patterns(&getContext());
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
fheContext.parameter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Apply conversion