mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
fix(compiler/midlfhe): Manage tensor in midlfhe parametrization
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
@@ -27,16 +28,30 @@ class MidLFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter {
|
||||
public:
|
||||
MidLFHEGlobalParametrizationTypeConverter(
|
||||
mlir::zamalang::V0FHEContext &fheContext) {
|
||||
auto convertGLWECiphertextType =
|
||||
[](GLWECipherTextType type, mlir::zamalang::V0FHEContext &fheContext) {
|
||||
auto glweSize = fheContext.parameter.getNBigGlweSize();
|
||||
auto p = fheContext.constraint.p;
|
||||
if (type.getDimension() == glweSize && type.getP() == p) {
|
||||
return type;
|
||||
}
|
||||
return GLWECipherTextType::get(
|
||||
type.getContext(), glweSize,
|
||||
1 /*for the v0, is always lwe ciphertext*/,
|
||||
64 /*for the v0 we handle only q=64*/, p);
|
||||
};
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](GLWECipherTextType type) {
|
||||
auto glweSize = fheContext.parameter.getNBigGlweSize();
|
||||
auto p = fheContext.constraint.p;
|
||||
if (type.getDimension() == glweSize && type.getP() == p) {
|
||||
return type;
|
||||
return convertGLWECiphertextType(type, fheContext);
|
||||
});
|
||||
addConversion([&](mlir::RankedTensorType type) {
|
||||
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe == nullptr) {
|
||||
return (mlir::Type)(type);
|
||||
}
|
||||
return GLWECipherTextType::get(type.getContext(), glweSize,
|
||||
1 /*for the v0, is always lwe ciphertext*/,
|
||||
64 /*for the v0 we handle only q=64*/, p);
|
||||
mlir::Type r = mlir::RankedTensorType::get(
|
||||
type.getShape(), convertGLWECiphertextType(glwe, fheContext));
|
||||
return r;
|
||||
});
|
||||
}
|
||||
};
|
||||
@@ -177,6 +192,8 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
|
||||
fheContext.parameter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Apply conversion
|
||||
|
||||
Reference in New Issue
Block a user