From fb680340f9b794e478cb588a42d76cf272ef1595 Mon Sep 17 00:00:00 2001 From: aPere3 Date: Thu, 6 Oct 2022 10:55:10 +0200 Subject: [PATCH] feat(concrete-compiler): add new ciphertext multiplication operator --- .../concretelang/Dialect/FHE/IR/FHEOps.td | 38 ++++ .../Dialect/FHE/Transforms/CMakeLists.txt | 4 + .../FHE/Transforms/EncryptedMulToDoubleTLU.h | 24 +++ .../FHE/Transforms/EncryptedMulToDoubleTLU.td | 11 ++ .../include/concretelang/Support/Pipeline.h | 4 + .../Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp | 40 ++++ .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 41 +++- compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 34 ++++ compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 16 ++ .../lib/Dialect/FHE/Transforms/CMakeLists.txt | 1 + .../Transforms/EncryptedMulToDoubleTLU.cpp | 178 +++++++++++++++++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/CompilerEngine.cpp | 7 + compiler/lib/Support/Pipeline.cpp | 11 ++ .../Dialect/FHE/Transforms/mul_eint.mlir | 17 ++ .../Dialect/FHE/mul_eint.invalid.mlir | 15 ++ .../tests/check_tests/Dialect/FHE/ops.mlir | 9 + .../end_to_end_leveled_gen.py | 187 ++++++++++++++++++ 18 files changed, 637 insertions(+), 1 deletion(-) create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.td create mode 100644 compiler/lib/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.cpp create mode 100644 compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/mul_eint.invalid.mlir diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index cd129b51a..cf6d195e4 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -285,6 +285,44 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> { let hasFolder = 1; } +def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> { + let summary = "Multiplies two encrypted integers"; + + let description = [{ + Multiplies two encrypted integers. + + The encrypted integers and the result must have the same width and + signedness. Also, due to the current implementation, one supplementary + bit of width must be provided, in addition to the number of bits needed + to encode the largest output value. + + Example: + ```mlir + // ok + "FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + "FHE.mul_eint"(%a, %b): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>) + "FHE.mul_eint"(%a, %b): (!FHE.esint<3>, !FHE.esint<3>) -> (!FHE.esint<3>) + + // error + "FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) + "FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) + "FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>) + "FHE.mul_eint"(%a, %b): (!FHE.esint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + ``` + }]; + + let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b); + let results = (outs FHE_AnyEncryptedInteger); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; + + let hasVerifier = 1; +} + def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> { let summary = "Cast an unsigned integer to a signed one"; diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt index 1cac703ce..4b41f4866 100644 --- a/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt @@ -2,3 +2,7 @@ set(LLVM_TARGET_DEFINITIONS Boolean.td) mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms) add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen) add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen) +set(LLVM_TARGET_DEFINITIONS EncryptedMulToDoubleTLU.td) +mlir_tablegen(EncryptedMulToDoubleTLU.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen) +add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen) diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h b/compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h new file mode 100644 index 000000000..327898af7 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h @@ -0,0 +1,24 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS_H +#define CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS_H + +#include +#include +#include + +#define GEN_PASS_CLASSES + +#include + +namespace mlir { +namespace concretelang { +std::unique_ptr> +createEncryptedMulToDoubleTLUPass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.td b/compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.td new file mode 100644 index 000000000..e77f32815 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.td @@ -0,0 +1,11 @@ +#ifndef CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS +#define CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS + +include "mlir/Pass/PassBase.td" + +def EncryptedMulToDoubleTLU : Pass<"EncryptedMulToDoubleTLU", "::mlir::func::FuncOp"> { + let summary = "Replaces encrypted multiplication with a double table lookup."; + let constructor = "mlir::concretelang::createEncryptedMulToDoubleTLUPass()"; +} + +#endif diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 9067e3153..81c4cd7da 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -34,6 +34,10 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::ArrayRef tileSizes, std::function enablePass); +mlir::LogicalResult +transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + mlir::LogicalResult lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::Optional &fheContext, diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 5a89a171a..1b1ac5f18 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -458,6 +458,42 @@ struct NegEintOpPattern : CrtOpPattern { } }; +/// Rewriter for the `FHE::to_signed` operation. +struct ToSignedOpPattern : public CrtOpPattern { + ToSignedOpPattern(mlir::MLIRContext *context, + concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHE::ToSignedOp op, FHE::ToSignedOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + typing::TypeConverter converter{loweringParameters}; + rewriter.replaceOp(op, {adaptor.input()}); + + return mlir::success(); + } +}; + +/// Rewriter for the `FHE::to_unsigned` operation. +struct ToUnsignedOpPattern : public CrtOpPattern { + ToUnsignedOpPattern(mlir::MLIRContext *context, + concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHE::ToUnsignedOp op, FHE::ToUnsignedOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + typing::TypeConverter converter{loweringParameters}; + rewriter.replaceOp(op, {adaptor.input()}); + + return mlir::success(); + } +}; + /// Rewriter for the `FHE::mul_eint_int` operation. struct MulEintIntOpPattern : CrtOpPattern { @@ -937,6 +973,10 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { lowering::NegEintOpPattern, // |_ `FHE::mul_eint_int` lowering::MulEintIntOpPattern, + // |_ `FHE::to_unsigned` + lowering::ToUnsignedOpPattern, + // |_ `FHE::to_signed` + lowering::ToSignedOpPattern, // |_ `FHE::apply_lookup_table` lowering::ApplyLookupTableEintOpPattern>(&getContext(), loweringParameters); diff --git a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 2ca43f3a4..301a43fec 100644 --- a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -273,6 +273,41 @@ struct MulEintIntOpPattern : public ScalarOpPattern { } }; +/// Rewriter for the `FHE::to_signed` operation. +struct ToSignedOpPattern : public ScalarOpPattern { + ToSignedOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ScalarOpPattern(converter, context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHE::ToSignedOp op, FHE::ToSignedOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + typing::TypeConverter converter; + rewriter.replaceOp(op, {adaptor.input()}); + + return mlir::success(); + } +}; + +/// Rewriter for the `FHE::to_unsigned` operation. +struct ToUnsignedOpPattern : public ScalarOpPattern { + ToUnsignedOpPattern(mlir::TypeConverter &converter, + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ScalarOpPattern(converter, context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHE::ToUnsignedOp op, FHE::ToUnsignedOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + typing::TypeConverter converter; + rewriter.replaceOp(op, {adaptor.input()}); + + return mlir::success(); + } +}; + /// Rewriter for the `FHE::apply_lookup_table` operation. struct ApplyLookupTableEintOpPattern : public ScalarOpPattern { @@ -474,7 +509,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { // |_ `FHE::sub_eint` lowering::SubEintOpPattern, // |_ `FHE::mul_eint_int` - lowering::MulEintIntOpPattern>(converter, &getContext()); + lowering::MulEintIntOpPattern, + // |_ `FHE::to_signed` + lowering::ToSignedOpPattern, + // |_ `FHE::to_unsigned` + lowering::ToUnsignedOpPattern>(converter, &getContext()); // |_ `FHE::apply_lookup_table` patterns.add( converter, &getContext(), loweringParameters); diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 15ab78667..0c489367a 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -424,6 +424,34 @@ static llvm::APInt getSqMANP( return eNorm; } +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::ToSignedOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() == 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + + return eNorm; +} + +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::ToUnsignedOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() == 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + + return eNorm; +} + /// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation /// that is equivalent to an `FHE.mul_eint_int` operation. static llvm::APInt getSqMANP( @@ -1139,6 +1167,12 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (auto boolNotOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(boolNotOp, operands); + } else if (auto toSignedOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(toSignedOp, operands); + } else if (auto toUnsignedOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(toUnsignedOp, operands); } else if (auto mulEintIntOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index 5896fb09c..18915270a 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -177,6 +177,22 @@ mlir::LogicalResult MulEintIntOp::verify() { return mlir::success(); } +mlir::LogicalResult MulEintOp::verify() { + auto a = this->a().getType().dyn_cast(); + auto b = this->b().getType().dyn_cast(); + auto out = this->getResult().getType().dyn_cast(); + + if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, + out)) { + return ::mlir::failure(); + } + + return ::mlir::success(); +} + mlir::LogicalResult ToSignedOp::verify() { auto input = this->input().getType().cast(); auto output = this->getResult().getType().cast(); diff --git a/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt index b8972f347..853c73e76 100644 --- a/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library( FHEDialectTransforms Boolean.cpp + EncryptedMulToDoubleTLU.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE DEPENDS diff --git a/compiler/lib/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.cpp b/compiler/lib/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.cpp new file mode 100644 index 000000000..0d19dfb0a --- /dev/null +++ b/compiler/lib/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.cpp @@ -0,0 +1,178 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir::concretelang::FHE; + +namespace mlir { +namespace concretelang { +namespace { + +class EncryptedMulOpPattern : public mlir::OpConversionPattern { +public: + EncryptedMulOpPattern(mlir::MLIRContext *context) + : mlir::OpConversionPattern( + context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + mlir::LogicalResult + matchAndRewrite(FHE::MulEintOp op, FHE::MulEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + auto inputType = adaptor.a().getType(); + auto bitWidth = inputType.cast().getWidth(); + auto isSigned = inputType.cast().isSigned(); + mlir::Type signedType = + FHE::EncryptedSignedIntegerType::get(op->getContext(), bitWidth); + + // Note: + // ----- + // + // The signedness of a value is only important: + // + when used as function input / output, because it changes the + // encoding/decoding used. + // + when used as tlu input, because it changes the encoding of the lut. + // + // Otherwise, for the leveled operations, the semantics are compatible. We + // just have to please the verifier that usually requires the same + // signedness for inputs and outputs. + + // s = a + b + mlir::Value sum = + rewriter.create(op->getLoc(), adaptor.a(), adaptor.b()); + + // se = (s)^2/4 + // Depending on whether a,b,s are signed or not, we need a different lut to + // compute (.)^2/4. + mlir::SmallVector rawSumLut; + if (isSigned) { + rawSumLut = generateSignedLut(bitWidth); + } else { + rawSumLut = generateUnsignedLut(bitWidth); + } + mlir::Value sumLut = rewriter.create( + op->getLoc(), mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get( + rawSumLut.size(), rewriter.getIntegerType(64)), + rawSumLut)); + mlir::Value sumTluOutput = rewriter.create( + op->getLoc(), inputType, sum, sumLut); + + // d = a - b + mlir::Value diff = + rewriter.create(op->getLoc(), adaptor.a(), adaptor.b()); + + // de = (d)^2/4 + // Here, the tlu must be performed with signed encoded lut, to properly + // bootstrap negative values that may arise in the computation of d. If the + // inputs are not signed, we cast the output to a signed encrypted integer. + mlir::Value diffO; + if (isSigned) { + diffO = diff; + } else { + diff = rewriter.create(op->getLoc(), signedType, diff); + } + mlir::SmallVector rawDiffLut = generateSignedLut(bitWidth); + mlir::Value diffLut = rewriter.create( + op->getLoc(), mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get( + rawDiffLut.size(), rewriter.getIntegerType(64)), + rawDiffLut)); + mlir::Value diffTluOutput = rewriter.create( + op->getLoc(), inputType, diff, diffLut); + + // o = se - de + mlir::Value output = rewriter.create( + op->getLoc(), inputType, sumTluOutput, diffTluOutput); + + rewriter.replaceOp(op, {output}); + + return mlir::success(); + } + +private: + static mlir::SmallVector generateUnsignedLut(unsigned bitWidth) { + mlir::SmallVector rawLut; + uint64_t lutLen = 1 << bitWidth; + for (uint64_t i = 0; i < lutLen; ++i) { + rawLut.push_back((i * i) / 4); + } + return rawLut; + } + + static mlir::SmallVector generateSignedLut(unsigned bitWidth) { + mlir::SmallVector rawLut; + uint64_t lutLen = 1 << bitWidth; + for (uint64_t i = 0; i < lutLen / 2; ++i) { + rawLut.push_back((i * i) / 4); + } + for (uint64_t i = lutLen / 2; i > 0; --i) { + rawLut.push_back((i * i) / 4); + } + return rawLut; + } +}; + +} // namespace + +/// This pass rewrites an `FHE::MulEintOp` into a set of ops of the `FHE` +/// dialects. +/// +/// It relies on the observation that `x*y` can be turned into `((x+y)^2)/4 - +/// ((x-y)^2)/4`, which uses operations already available in the `FHE` dialect: +/// + `x+y` can be computed with the leveled operation `add_eint` +/// + `x-y` can be computed with the leveled operation `sub_eint` +/// + `(a^2)/4` can be computed with a table lookup `apply_table_lookup` +/// +/// Gotchas: +/// -------- +/// +/// + Since we use the leveled addition and subtraction, we have to increment +/// the bitwidth of the inputs to properly +/// encode the carry of the computation. This change in bitwidth must then be +/// propagated to the whole graph, both upstream and downstream. +/// + This graph-wide update may reach existing `apply_lookup_table` operations, +/// which in turn will necessitate an +/// update of the size of the lookup table. +class EncryptedMulToDoubleTLU + : public EncryptedMulToDoubleTLUBase { + +public: + void runOnOperation() override { + mlir::func::FuncOp funcOp = getOperation(); + + mlir::ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalOp(); + + mlir::RewritePatternSet patterns(funcOp->getContext()); + patterns.add(funcOp->getContext()); + if (mlir::applyPartialConversion(funcOp, target, std::move(patterns)) + .failed()) { + funcOp->emitError("Failed to rewrite FHE mul_eint operation."); + this->signalPassFailure(); + } + } +}; + +std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>> +createEncryptedMulToDoubleTLUPass() { + return std::make_unique(); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 358e90659..b3ea2fc7b 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -24,6 +24,7 @@ add_mlir_library( ExtractSDFGOps MLIRLowerableDialectsToLLVM FHEDialectAnalysis + FHEDialectTransforms RTDialectAnalysis ConcretelangTransforms ConcretelangBConcreteTransforms diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 854719ca2..e4467581a 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -289,6 +289,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return errorDiag("Transforming FHE boolean ops failed"); } + // Encrypted mul rewriting + if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext, + module, enablePass) + .failed()) { + return StreamStringError("Rewriting of encrypted mul failed"); + } + // FHE High level pass to determine FHE parameters if (auto err = this->determineFHEParameters(res)) return std::move(err); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 6450cd7eb..1015ecd45 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -174,6 +175,16 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("transformHighLevelFHEOps", pm, context); + addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass); + + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::Optional &fheContext, diff --git a/compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir b/compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir new file mode 100644 index 000000000..0b11156ec --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir @@ -0,0 +1,17 @@ +// RUN: concretecompiler --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s + +// CHECK: func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> { +// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64> +// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64> +// CHECK-NEXT: %0 = "FHE.add_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> +// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst_0) {MANP = 1 : ui1} : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3> +// CHECK-NEXT: %2 = "FHE.sub_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> +// CHECK-NEXT: %3 = "FHE.to_signed"(%2) {MANP = 2 : ui3} : (!FHE.eint<3>) -> !FHE.esint<3> +// CHECK-NEXT: %4 = "FHE.apply_lookup_table"(%3, %cst) {MANP = 1 : ui1} : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3> +// CHECK-NEXT: %5 = "FHE.sub_eint"(%1, %4) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> +// CHECK-NEXT: return %5 : !FHE.eint<3> +// CHECK-NEXT: } +func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> { + %0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>) + return %0: !FHE.eint<3> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/mul_eint.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/mul_eint.invalid.mlir new file mode 100644 index 000000000..12ebc10a2 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/mul_eint.invalid.mlir @@ -0,0 +1,15 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.mul_eint' op should have the width of encrypted inputs equal +func.func @bad_inputs_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> { + %1 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.mul_eint' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> { + %1 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.mlir index 014dd522d..418f40e51 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/ops.mlir @@ -176,6 +176,15 @@ func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func.func @mul_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> +func.func @mul_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: %[[V0:.*]] = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: return %[[V0]] : !FHE.eint<2> + + %0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + return %0: !FHE.eint<2> +} + // CHECK-LABEL: func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> { // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 diff --git a/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py b/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py index c0ae6cc8e..b3f4e8d74 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py +++ b/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py @@ -1,3 +1,5 @@ +import random + MIN_PRECISON = 1 from end_to_end_linalg_leveled_gen import P_ERROR @@ -302,6 +304,56 @@ def main(): print(" - scalar: {0}".format(max_value)) may_check_error_rate() print("---") + # mul_eint + if p <= 15: + def gen_random_encodable(): + while True: + a = random.randint(1, max_value) + b = random.randint(1, max_value) + if a*b <= max_value: + return a, b + + print("description: mul_eint_{0}bits".format(p+1)) + print("program: |") + print( + " func.func @main(%arg0: !FHE.eint<{0}>, %arg1: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p+1)) + print( + " %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.eint<{0}>, !FHE.eint<{0}>) -> (!FHE.eint<{0}>)".format(p+1)) + print(" return %1: !FHE.eint<{0}>".format(p+1)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: 0") + print(" - scalar: 0") + print(" outputs:") + print(" - scalar: 0") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" - scalar: 0") + print(" outputs:") + print(" - scalar: 0") + print(" - inputs:") + print(" - scalar: 0") + print(" - scalar: {0}".format(max_value)) + print(" outputs:") + print(" - scalar: 0") + print(" - inputs:") + print(" - scalar: 1") + print(" - scalar: {0}".format(max_value)) + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" - scalar: 1") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + inp = gen_random_encodable() + print(" - inputs:") + print(" - scalar: {0}".format(inp[0])) + print(" - scalar: {0}".format(inp[1])) + print(" outputs:") + print(" - scalar: {0}".format(inp[0]*inp[1])) + print("---") # signed for p in range(MIN_PRECISON, MAX_PRECISION+1): print("---") @@ -877,6 +929,141 @@ def main(): may_check_error_rate() print("---") + # mul_eint + if 2 <= p <= 15: + def gen_random_encodable(p): + while True: + a = random.randint(min_value, max_value) + b = random.randint(min_value, max_value) + if min_value <= a*b <= max_value: + if p == 3: + return a, b + if not (a in [-1, 1, 0] or b in [-1, 1, 0]): + return a, b + + print("description: signed_mul_eint_{0}bits".format(p+1)) + print("program: |") + print( + " func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p+1)) + print( + " %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p+1)) + print(" return %1: !FHE.esint<{0}>".format(p+1)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - scalar: 0") + print(" signed: true") + print(" outputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: 0") + print(" signed: true") + print(" outputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - scalar: 0") + print(" signed: true") + print(" outputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - inputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - inputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: 0") + print(" signed: true") + print(" - inputs:") + print(" - scalar: 1") + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: 1") + print(" signed: true") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: 1") + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - scalar: 1") + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: -1") + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: -1") + print(" signed: true") + print(" - scalar: {0}".format(min_value+1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-(min_value+1))) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: -1") + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value+1)) + print(" signed: true") + print(" - scalar: -1") + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-(min_value+1))) + print(" signed: true") + inp = gen_random_encodable(p+1) + print(" - inputs:") + print(" - scalar: {0}".format(inp[0])) + print(" signed: true") + print(" - scalar: {0}".format(inp[1])) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(inp[0]*inp[1])) + print(" signed: true") + print("---") if __name__ == "__main__": main()