mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
fix(compiler): Since the patterns could be call in any order take care of already lowered types in mid to low conversion
This commit is contained in:
@@ -15,22 +15,66 @@ using LowLFHE::LweCiphertextType;
|
||||
using LowLFHE::PlaintextType;
|
||||
using MidLFHE::GLWECipherTextType;
|
||||
|
||||
LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context,
|
||||
GLWECipherTextType &glwe) {
|
||||
return LweCiphertextType::get(
|
||||
context, glwe.getDimension() * glwe.getPolynomialSize(), glwe.getP());
|
||||
LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
|
||||
mlir::Type type) {
|
||||
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe != nullptr) {
|
||||
return LweCiphertextType::get(
|
||||
context, glwe.getDimension() * glwe.getPolynomialSize(), glwe.getP());
|
||||
}
|
||||
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
|
||||
if (lwe != nullptr) {
|
||||
return lwe;
|
||||
}
|
||||
assert(false && "expect glwe or lwe");
|
||||
}
|
||||
|
||||
PlaintextType convertPlaintextTypeFromGlwe(mlir::MLIRContext *context,
|
||||
GLWECipherTextType &type) {
|
||||
template <typename PType>
|
||||
PlaintextType convertPlaintextTypeFromPType(mlir::MLIRContext *context,
|
||||
PType &type) {
|
||||
return PlaintextType::get(context, type.getP() + 1);
|
||||
}
|
||||
|
||||
CleartextType convertCleartextTypeFromGlwe(mlir::MLIRContext *context,
|
||||
GLWECipherTextType &type) {
|
||||
// convertPlaintextTypeFromType create a plaintext type according the
|
||||
// precision of the given type argument. The type should be a GLWECipherText
|
||||
// (if operand is not yet lowered) or a LWECipherTextType (if operand is
|
||||
// already lowered).
|
||||
PlaintextType convertPlaintextTypeFromType(mlir::MLIRContext *context,
|
||||
mlir::Type &type) {
|
||||
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe != nullptr) {
|
||||
return convertPlaintextTypeFromPType<GLWECipherTextType>(context, glwe);
|
||||
}
|
||||
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
|
||||
if (lwe != nullptr) {
|
||||
return convertPlaintextTypeFromPType<LweCiphertextType>(context, lwe);
|
||||
}
|
||||
assert(false && "expect glwe or lwe");
|
||||
}
|
||||
|
||||
template <typename PType>
|
||||
CleartextType convertCleartextTypeFromPType(mlir::MLIRContext *context,
|
||||
PType &type) {
|
||||
return CleartextType::get(context, type.getP() + 1);
|
||||
}
|
||||
|
||||
// convertCleartextTypeFromType create a cleartext type according the
|
||||
// precision of the given type argument. The type should be a GLWECipherText
|
||||
// (if operand is not yet lowered) or a LWECipherTextType (if operand is
|
||||
// already lowered).
|
||||
CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context,
|
||||
mlir::Type &type) {
|
||||
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe != nullptr) {
|
||||
return convertCleartextTypeFromPType<GLWECipherTextType>(context, glwe);
|
||||
}
|
||||
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
|
||||
if (lwe != nullptr) {
|
||||
return convertCleartextTypeFromPType<LweCiphertextType>(context, lwe);
|
||||
}
|
||||
assert(false && "expect glwe or lwe");
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
@@ -40,16 +84,16 @@ mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto glwe = result.getType().cast<GLWECipherTextType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe)};
|
||||
convertTypeToLWE(rewriter.getContext(), glwe)};
|
||||
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
mlir::Value createAddPlainLweCiphertextWithGlwe(
|
||||
mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result, GLWECipherTextType glwe) {
|
||||
mlir::Value arg1, mlir::OpResult result, mlir::Type encryptedType) {
|
||||
PlaintextType encoded_type =
|
||||
convertPlaintextTypeFromGlwe(rewriter.getContext(), glwe);
|
||||
convertPlaintextTypeFromType(rewriter.getContext(), encryptedType);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded =
|
||||
rewriter
|
||||
@@ -58,7 +102,7 @@ mlir::Value createAddPlainLweCiphertextWithGlwe(
|
||||
// convert result type
|
||||
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
|
||||
LweCiphertextType lwe_type =
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
|
||||
convertTypeToLWE(rewriter.getContext(), result.getType());
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter.create<mlir::zamalang::LowLFHE::AddPlaintextLweCiphertextOp>(
|
||||
@@ -78,11 +122,11 @@ mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result) {
|
||||
auto arg1_type = arg1.getType().cast<GLWECipherTextType>();
|
||||
auto arg1_type = arg1.getType();
|
||||
auto negated_arg1 =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::NegateLweCiphertextOp>(
|
||||
loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1)
|
||||
loc, convertTypeToLWE(rewriter.getContext(), arg1_type), arg1)
|
||||
.result();
|
||||
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0,
|
||||
result, arg1_type);
|
||||
@@ -92,18 +136,17 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
auto glwe = arg0.getType().cast<GLWECipherTextType>();
|
||||
auto inType = arg0.getType();
|
||||
CleartextType encoded_type =
|
||||
convertCleartextTypeFromGlwe(rewriter.getContext(), glwe);
|
||||
convertCleartextTypeFromType(rewriter.getContext(), inType);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded = rewriter
|
||||
.create<mlir::zamalang::LowLFHE::IntToCleartextOp>(
|
||||
loc, encoded_type, arg1)
|
||||
.cleartext();
|
||||
// convert result type
|
||||
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
|
||||
LweCiphertextType lwe_type =
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
|
||||
auto resType = result.getType();
|
||||
LweCiphertextType lwe_type = convertTypeToLWE(rewriter.getContext(), resType);
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter.create<mlir::zamalang::LowLFHE::MulCleartextLweCiphertextOp>(
|
||||
|
||||
Reference in New Issue
Block a user