From 15fd194075ac5967c039113a882bd17e195ab136 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 25 Aug 2021 14:07:42 +0200 Subject: [PATCH] fix(compiler): Since the patterns could be call in any order take care of already lowered types in mid to low conversion --- .../Conversion/MidLFHEToLowLFHE/Patterns.h | 81 ++++++++++++++----- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index f847a0e2e..48d5b65db 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -15,22 +15,66 @@ using LowLFHE::LweCiphertextType; using LowLFHE::PlaintextType; using MidLFHE::GLWECipherTextType; -LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context, - GLWECipherTextType &glwe) { - return LweCiphertextType::get( - context, glwe.getDimension() * glwe.getPolynomialSize(), glwe.getP()); +LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context, + mlir::Type type) { + auto glwe = type.dyn_cast_or_null(); + if (glwe != nullptr) { + return LweCiphertextType::get( + context, glwe.getDimension() * glwe.getPolynomialSize(), glwe.getP()); + } + auto lwe = type.dyn_cast_or_null(); + if (lwe != nullptr) { + return lwe; + } + assert(false && "expect glwe or lwe"); } -PlaintextType convertPlaintextTypeFromGlwe(mlir::MLIRContext *context, - GLWECipherTextType &type) { +template +PlaintextType convertPlaintextTypeFromPType(mlir::MLIRContext *context, + PType &type) { return PlaintextType::get(context, type.getP() + 1); } -CleartextType convertCleartextTypeFromGlwe(mlir::MLIRContext *context, - GLWECipherTextType &type) { +// convertPlaintextTypeFromType create a plaintext type according the +// precision of the given type argument. The type should be a GLWECipherText +// (if operand is not yet lowered) or a LWECipherTextType (if operand is +// already lowered). +PlaintextType convertPlaintextTypeFromType(mlir::MLIRContext *context, + mlir::Type &type) { + auto glwe = type.dyn_cast_or_null(); + if (glwe != nullptr) { + return convertPlaintextTypeFromPType(context, glwe); + } + auto lwe = type.dyn_cast_or_null(); + if (lwe != nullptr) { + return convertPlaintextTypeFromPType(context, lwe); + } + assert(false && "expect glwe or lwe"); +} + +template +CleartextType convertCleartextTypeFromPType(mlir::MLIRContext *context, + PType &type) { return CleartextType::get(context, type.getP() + 1); } +// convertCleartextTypeFromType create a cleartext type according the +// precision of the given type argument. The type should be a GLWECipherText +// (if operand is not yet lowered) or a LWECipherTextType (if operand is +// already lowered). +CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context, + mlir::Type &type) { + auto glwe = type.dyn_cast_or_null(); + if (glwe != nullptr) { + return convertCleartextTypeFromPType(context, glwe); + } + auto lwe = type.dyn_cast_or_null(); + if (lwe != nullptr) { + return convertCleartextTypeFromPType(context, lwe); + } + assert(false && "expect glwe or lwe"); +} + template mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, @@ -40,16 +84,16 @@ mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter, mlir::SmallVector attrs; auto glwe = result.getType().cast(); mlir::SmallVector resTypes{ - convertTypeGLWEToLWE(rewriter.getContext(), glwe)}; + convertTypeToLWE(rewriter.getContext(), glwe)}; Operator op = rewriter.create(loc, resTypes, args, attrs); return op.getODSResults(0).front(); } mlir::Value createAddPlainLweCiphertextWithGlwe( mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, - mlir::Value arg1, mlir::OpResult result, GLWECipherTextType glwe) { + mlir::Value arg1, mlir::OpResult result, mlir::Type encryptedType) { PlaintextType encoded_type = - convertPlaintextTypeFromGlwe(rewriter.getContext(), glwe); + convertPlaintextTypeFromType(rewriter.getContext(), encryptedType); // encode int into plaintext mlir::Value encoded = rewriter @@ -58,7 +102,7 @@ mlir::Value createAddPlainLweCiphertextWithGlwe( // convert result type GLWECipherTextType glwe_type = result.getType().cast(); LweCiphertextType lwe_type = - convertTypeGLWEToLWE(rewriter.getContext(), glwe_type); + convertTypeToLWE(rewriter.getContext(), result.getType()); // replace op using the encoded plaintext instead of int auto op = rewriter.create( @@ -78,11 +122,11 @@ mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter, mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, mlir::Value arg1, mlir::OpResult result) { - auto arg1_type = arg1.getType().cast(); + auto arg1_type = arg1.getType(); auto negated_arg1 = rewriter .create( - loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1) + loc, convertTypeToLWE(rewriter.getContext(), arg1_type), arg1) .result(); return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0, result, arg1_type); @@ -92,18 +136,17 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, mlir::Value arg1, mlir::OpResult result) { - auto glwe = arg0.getType().cast(); + auto inType = arg0.getType(); CleartextType encoded_type = - convertCleartextTypeFromGlwe(rewriter.getContext(), glwe); + convertCleartextTypeFromType(rewriter.getContext(), inType); // encode int into plaintext mlir::Value encoded = rewriter .create( loc, encoded_type, arg1) .cleartext(); // convert result type - GLWECipherTextType glwe_type = result.getType().cast(); - LweCiphertextType lwe_type = - convertTypeGLWEToLWE(rewriter.getContext(), glwe_type); + auto resType = result.getType(); + LweCiphertextType lwe_type = convertTypeToLWE(rewriter.getContext(), resType); // replace op using the encoded plaintext instead of int auto op = rewriter.create(