From d41d14dbb87ee17d57c7ea33249497b32426b7a7 Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 1 Feb 2023 09:46:41 +0100 Subject: [PATCH] feat: lower FHE.add on eint64 to ops on smaller chunks this is a first commit to support operations on U64 by decomposing them into smaller chunks (32 chunks of 2 bits). This commit introduce the lowering pass that will be later populated to support other operations. --- .../Dialect/FHE/IR/FHEInterfaces.td | 5 + .../concretelang/Dialect/FHE/IR/FHETypes.td | 34 ++- .../Dialect/FHE/Transforms/BigInt/BigInt.h | 24 ++ .../Dialect/FHE/Transforms/BigInt/BigInt.td | 13 + .../FHE/Transforms/BigInt/CMakeLists.txt | 4 + .../FHE/Transforms/{ => Boolean}/Boolean.h | 2 +- .../FHE/Transforms/{ => Boolean}/Boolean.td | 0 .../FHE/Transforms/Boolean/CMakeLists.txt | 4 + .../Dialect/FHE/Transforms/CMakeLists.txt | 6 +- .../concretelang/Support/CompilerEngine.h | 9 +- .../include/concretelang/Support/Pipeline.h | 5 + compiler/lib/Dialect/FHE/IR/FHEDialect.cpp | 30 +++ compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 21 +- .../lib/Dialect/FHE/Transforms/BigInt.cpp | 222 ++++++++++++++++++ .../lib/Dialect/FHE/Transforms/Boolean.cpp | 2 +- .../lib/Dialect/FHE/Transforms/CMakeLists.txt | 1 + compiler/lib/Support/CompilerEngine.cpp | 48 ++-- compiler/lib/Support/Pipeline.cpp | 20 +- compiler/src/main.cpp | 14 ++ .../FHE/Transform/big_int_transform.mlir | 24 ++ .../check_tests/Dialect/FHE/big_int.mlir | 34 +++ 21 files changed, 492 insertions(+), 30 deletions(-) create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.td create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/CMakeLists.txt rename compiler/include/concretelang/Dialect/FHE/Transforms/{ => Boolean}/Boolean.h (89%) rename compiler/include/concretelang/Dialect/FHE/Transforms/{ => Boolean}/Boolean.td (100%) create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/CMakeLists.txt create mode 100644 compiler/lib/Dialect/FHE/Transforms/BigInt.cpp create mode 100644 compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/big_int.mlir diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td index cb0afe564..587a088f6 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td @@ -25,6 +25,11 @@ def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> { /*description=*/"Get whether the integer is unsigned.", /*retTy=*/"bool", /*methodName=*/"isUnsigned" + >, + InterfaceMethod< + /*description=*/"Get whether the integer is chunked (composed of multiple smaller integers).", + /*retTy=*/"bool", + /*methodName=*/"isChunked" > ]; } diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td index e4cbd0bae..fe98730d3 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td @@ -33,6 +33,7 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger", let extraClassDeclaration = [{ bool isSigned() const { return false; } bool isUnsigned() const { return true; } + bool isChunked() const { return false; } }]; } @@ -61,12 +62,43 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger", let extraClassDeclaration = [{ bool isSigned() const { return true; } bool isUnsigned() const { return false; } + bool isChunked() const { return false; } + }]; +} + +def FHE_ChunkedEncryptedIntegerType : FHE_Type<"ChunkedEncryptedInteger", + [MemRefElementTypeInterface, FheIntegerInterface]> { + let mnemonic = "chunked_eint"; + + let summary = "An encrypted integer composed of multiple chunks"; + + let description = [{ + An encrypted integer composed of multiple chunks. + + Examples: + ```mlir + !FHE.chunked_eint<64> + !FHE.chunked_eint<32> + ``` + }]; + + let parameters = (ins "unsigned":$width); + + let hasCustomAssemblyFormat = 1; + + let genVerifyDecl = true; + + let extraClassDeclaration = [{ + bool isSigned() const { return false; } + bool isUnsigned() const { return true; } + bool isChunked() const { return true; } }]; } def FHE_AnyEncryptedInteger : Type>; def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean", diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h b/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h new file mode 100644 index 000000000..1de30e64c --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h @@ -0,0 +1,24 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_FHE_BIGINT_PASS_H +#define CONCRETELANG_FHE_BIGINT_PASS_H + +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace concretelang { + +std::unique_ptr> +createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth); + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.td b/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.td new file mode 100644 index 000000000..98cbc312a --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/BigInt.td @@ -0,0 +1,13 @@ +#ifndef CONCRETELANG_FHE_BIGINT_PASS +#define CONCRETELANG_FHE_BIGINT_PASS + +include "mlir/Pass/PassBase.td" + +def FHEBigIntTransform : Pass<"fhe-big-int-transform"> { + let summary = "Transform FHE operations on big integer into operations on chunks of small integer"; + let constructor = "mlir::concretelang::createFHEBigIntTransformPass()"; + let options = []; + let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ]; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/CMakeLists.txt new file mode 100644 index 000000000..c97d7ecb6 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/BigInt/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS BigInt.td) +mlir_tablegen(BigInt.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangFHEBigIntPassIncGen) +add_dependencies(mlir-headers ConcretelangFHEBigIntPassIncGen) diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.h b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h similarity index 89% rename from compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.h rename to compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h index 386c6a6a1..ce238d37c 100644 --- a/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.h +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h @@ -10,7 +10,7 @@ #include #define GEN_PASS_CLASSES -#include +#include namespace mlir { namespace concretelang { diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.td b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/Boolean.td similarity index 100% rename from compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.td rename to compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/Boolean.td diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/CMakeLists.txt new file mode 100644 index 000000000..1cac703ce --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Boolean.td) +mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen) +add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen) diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt index 4b41f4866..b237690f5 100644 --- a/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt @@ -1,8 +1,6 @@ -set(LLVM_TARGET_DEFINITIONS Boolean.td) -mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms) -add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen) -add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen) set(LLVM_TARGET_DEFINITIONS EncryptedMulToDoubleTLU.td) mlir_tablegen(EncryptedMulToDoubleTLU.h.inc -gen-pass-decls -name Transforms) add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen) add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen) +add_subdirectory(BigInt) +add_subdirectory(Boolean) diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index a43e1c9f5..624c2cede 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -71,13 +71,20 @@ struct CompilationOptions { optimizer::Config optimizerConfig; + /// When decomposing big integers into chunks, chunkSize is the total number + /// of bits used for the message, including the carry, while chunkWidth is + /// only the number of bits used during encoding and decoding of a big integer + unsigned int chunkSize; + unsigned int chunkWidth; + CompilationOptions() : v0FHEConstraints(llvm::None), verifyDiagnostics(false), autoParallelize(false), loopParallelize(false), batchConcreteOps(false), emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false), dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false), clientParametersFuncName(llvm::None), - optimizerConfig(optimizer::DEFAULT_CONFIG){}; + optimizerConfig(optimizer::DEFAULT_CONFIG), chunkSize(4), + chunkWidth(2){}; CompilationOptions(std::string funcname) : CompilationOptions() { clientParametersFuncName = funcname; diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 81c4cd7da..9152f75f0 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -48,6 +48,11 @@ mlir::LogicalResult transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); +mlir::LogicalResult +transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + unsigned int chunkSize, unsigned int chunkWidth); + mlir::LogicalResult lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::Optional &fheContext, diff --git a/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp b/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp index 3bea0d3a9..b8c2c078e 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp @@ -89,3 +89,33 @@ mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) { return getChecked(loc, loc.getContext(), width); } + +mlir::LogicalResult ChunkedEncryptedIntegerType::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) { + if (p == 0) { + emitError() << "FHE.chunked_eint doesn't support precision of 0"; + return mlir::failure(); + } + return mlir::success(); +} + +void ChunkedEncryptedIntegerType::print(mlir::AsmPrinter &p) const { + p << "<" << getWidth() << ">"; +} + +mlir::Type ChunkedEncryptedIntegerType::parse(mlir::AsmParser &p) { + if (p.parseLess()) + return mlir::Type(); + + int width; + + if (p.parseInteger(width)) + return mlir::Type(); + + if (p.parseGreater()) + return mlir::Type(); + + mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc()); + + return getChecked(loc, loc.getContext(), width); +} diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index 18915270a..bc03735c0 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -29,14 +29,25 @@ bool verifyEncryptedIntegerInputAndResultConsistency( return false; } + if (input.isChunked() != result.isChunked()) { + op.emitOpError("should have the same composition (chunked or not) of " + "encrypted input and result"); + return false; + } + return true; } bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op, FheIntegerInterface &a, IntegerType &b) { - - if (a.getWidth() + 1 != b.getWidth()) { + if (a.isChunked()) { + if (a.getWidth() != b.getWidth()) { + op.emitOpError("should have the width of plain input equal to width of " + "encrypted input"); + return false; + } + } else if (a.getWidth() + 1 != b.getWidth()) { op.emitOpError("should have the width of plain input equal to width of " "encrypted input + 1"); return false; @@ -58,6 +69,12 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op, return false; } + if (a.isChunked() != b.isChunked()) { + op.emitOpError("should have the same composition (chunked or not) of " + "encrypted inputs"); + return false; + } + return true; } diff --git a/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp b/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp new file mode 100644 index 000000000..34c0ae632 --- /dev/null +++ b/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp @@ -0,0 +1,222 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace mlir { +namespace concretelang { + +/// Construct a table lookup to extract the carry bit +mlir::Value getTruthTableCarryExtract(mlir::PatternRewriter &rewriter, + mlir::Location loc, + unsigned int chunkSize, + unsigned int chunkWidth) { + auto tableSize = 1 << chunkSize; + std::vector values; + values.reserve(tableSize); + for (auto i = 0; i < tableSize; i++) { + if (i < 1 << chunkWidth) + values.push_back(llvm::APInt(1, 0, false)); + else + values.push_back(llvm::APInt(1, 1, false)); + } + auto truthTableAttr = mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get({tableSize}, rewriter.getIntegerType(64)), + values); + auto truthTable = + rewriter.create(loc, truthTableAttr); + return truthTable.getResult(); +} + +namespace { + +namespace typing { + +/// Converts `FHE::ChunkedEncryptedInteger` into a tensor of +/// `FHE::EncryptedInteger`. +mlir::RankedTensorType +convertChunkedEint(mlir::MLIRContext *context, + FHE::ChunkedEncryptedIntegerType chunkedEint, + unsigned int chunkSize, unsigned int chunkWidth) { + auto eint = FHE::EncryptedIntegerType::get(context, chunkSize); + auto bigIntWidth = chunkedEint.getWidth(); + assert(bigIntWidth % chunkWidth == 0 && + "chunkWidth must divide width of the big integer"); + auto numberOfChunks = bigIntWidth / chunkWidth; + std::vector shape({numberOfChunks}); + return mlir::RankedTensorType::get(shape, eint); +} + +/// The type converter used to transform `FHE` ops on chunked integers +class TypeConverter : public mlir::TypeConverter { + +public: + TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) { + addConversion([](mlir::Type type) { return type; }); + addConversion([chunkSize, + chunkWidth](FHE::ChunkedEncryptedIntegerType type) { + return convertChunkedEint(type.getContext(), type, chunkSize, chunkWidth); + }); + } +}; + +} // namespace typing + +class AddEintPattern + : public mlir::OpConversionPattern { +public: + AddEintPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + unsigned int chunkSize, unsigned int chunkWidth) + : mlir::OpConversionPattern( + converter, context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT), + chunkSize(chunkSize), chunkWidth(chunkWidth) {} + + mlir::LogicalResult + matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto tensorType = adaptor.a().getType().dyn_cast(); + auto shape = tensorType.getShape(); + assert(shape.size() == 1 && + "chunked integer should be converted to flat tensors, but tensor " + "have more than one dimension"); + auto eintChunkWidth = tensorType.getElementType() + .dyn_cast() + .getWidth(); + assert(eintChunkWidth == chunkSize && "wrong tensor elements width"); + auto numberOfChunks = shape[0]; + + mlir::Value carry = + rewriter + .create(op.getLoc(), + FHE::EncryptedIntegerType::get( + rewriter.getContext(), chunkSize)) + .getResult(); + + mlir::Value resultTensor = + rewriter.create(op.getLoc(), adaptor.a().getType()) + .getResult(); + // used to shift the carry bit to the left + mlir::Value twoPowerChunkSizeCst = + rewriter + .create(op.getLoc(), 1 << chunkWidth, + chunkSize + 1) + .getResult(); + // Create the loop + int64_t lb = 0, step = 1; + auto forOp = rewriter.create( + op.getLoc(), lb, numberOfChunks, step, resultTensor, + [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, + mlir::ValueRange args) { + // add inputs with the previous carry (init to 0) + mlir::Value leftEint = + builder.create(loc, adaptor.a(), iter); + mlir::Value rightEint = + builder.create(loc, adaptor.b(), iter); + mlir::Value result = + builder.create(loc, leftEint, rightEint) + .getResult(); + mlir::Value resultWithCarry = + builder.create(loc, result, carry).getResult(); + // compute the new carry: either 1 or 0 + carry = + rewriter.create( + op.getLoc(), + FHE::EncryptedIntegerType::get(rewriter.getContext(), + chunkSize), + resultWithCarry, + getTruthTableCarryExtract(rewriter, op.getLoc(), chunkSize, + chunkWidth)); + // remove the carry bit from the result + mlir::Value shiftedCarry = + builder + .create(loc, carry, twoPowerChunkSizeCst) + .getResult(); + mlir::Value finalResult = + builder.create(loc, resultWithCarry, shiftedCarry) + .getResult(); + // insert the result in the result tensor + mlir::Value tensorResult = builder.create( + loc, finalResult, args[0], iter); + builder.create(loc, tensorResult); + }); + rewriter.replaceOp(op, forOp.getResult(0)); + return mlir::success(); + } + +private: + unsigned int chunkSize, chunkWidth; +}; + +/// Perfoms the transformation of big integer operations +class FHEBigIntTransformPass + : public FHEBigIntTransformBase { +public: + FHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth) + : chunkSize(chunkSize), chunkWidth(chunkWidth){}; + + void runOnOperation() override { + mlir::Operation *op = getOperation(); + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + typing::TypeConverter converter(chunkSize, chunkWidth); + + // Legal ops created during pattern application + target.addLegalOp(); + concretelang::addDynamicallyLegalTypeOp(target, converter); + // Func ops are only legal with converted types + target.addDynamicallyLegalOp( + [&](mlir::func::FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getFunctionType()) && + converter.isLegal(&funcOp.getBody()); + }); + mlir::populateFunctionOpInterfaceTypeConversionPattern( + patterns, converter); + patterns.add>(patterns.getContext(), converter); + concretelang::addDynamicallyLegalTypeOp(target, + converter); + + patterns.add(converter, &getContext(), chunkSize, + chunkWidth); + + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } + } + +private: + unsigned int chunkSize, chunkWidth; +}; + +} // end anonymous namespace + +std::unique_ptr> +createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth) { + assert(chunkSize >= chunkWidth + 1 && + "chunkSize must be greater than chunkWidth"); + return std::make_unique(chunkSize, chunkWidth); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp b/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp index c559a39c0..a3a2e1300 100644 --- a/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp +++ b/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include namespace mlir { diff --git a/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt index 853c73e76..51b71f1f2 100644 --- a/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library( FHEDialectTransforms + BigInt.cpp Boolean.cpp EncryptedMulToDoubleTLU.cpp ADDITIONAL_HEADER_DIRS diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index e4467581a..88c56fa92 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -296,6 +296,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return StreamStringError("Rewriting of encrypted mul failed"); } + if (mlir::concretelang::pipeline::transformFHEBigInt( + mlirContext, module, enablePass, options.chunkSize, + options.chunkWidth) + .failed()) { + return errorDiag("Transforming FHE big integer ops failed"); + } + // FHE High level pass to determine FHE parameters if (auto err = this->determineFHEParameters(res)) return std::move(err); @@ -414,8 +421,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { mlirContext, module, enablePass, options.unrollLoopsWithSDFGConvertibleOps) .failed()) { - return errorDiag( - "Extraction of SDFG operations from BConcrete representation failed"); + return errorDiag("Extraction of SDFG operations from BConcrete " + "representation failed"); } } @@ -427,8 +434,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module, enablePass) .failed()) { - return errorDiag( - "Lowering from Bufferized Concrete to canonical MLIR dialects failed"); + return errorDiag("Lowering from Bufferized Concrete to canonical MLIR " + "dialects failed"); } // SDFG -> Canonical dialects @@ -595,16 +602,17 @@ CompilerEngine::Library::getCompilationFeedbackPath(std::string outputDirPath) { const std::string CompilerEngine::Library::OBJECT_EXT = ".o"; const std::string CompilerEngine::Library::LINKER = "ld"; #ifdef __APPLE__ -// We need to tell the linker that some symbols will be missing during linking, -// this symbols should be available during runtime however. This is the case -// when JIT compiling, the JIT should either link to the runtime library that -// has the missing symbols, or it would have been loaded even prior to that. -// Starting from Mac 11 (Big Sur), it appears we need to add -L -// /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem for the -// sharedlib to link properly. +// We need to tell the linker that some symbols will be missing during +// linking, this symbols should be available during runtime however. This is +// the case when JIT compiling, the JIT should either link to the runtime +// library that has the missing symbols, or it would have been loaded even +// prior to that. Starting from Mac 11 (Big Sur), it appears we need to add -L +// /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem for +// the sharedlib to link properly. const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " -dylib -undefined dynamic_lookup -L " - "/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem -o "; + "/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem " + "-o "; const std::string CompilerEngine::Library::DOT_SHARED_LIB_EXT = ".dylib"; #else // Linux const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " --shared -o "; @@ -798,7 +806,8 @@ llvm::Expected CompilerEngine::Library::emitShared() { std::vector extraArgs; std::string fullRuntimeLibraryName = ""; #ifdef __APPLE__ - // to issue the command for fixing the runtime dependency of the generated lib + // to issue the command for fixing the runtime dependency of the generated + // lib bool fixRuntimeDep = false; #endif if (!runtimeLibraryPath.empty()) { @@ -841,12 +850,13 @@ llvm::Expected CompilerEngine::Library::emitShared() { sharedLibraryPath = path.get(); #ifdef __APPLE__ // when dellocate is used to include dependencies in python wheels, the - // runtime library will have an id that is prefixed with /DLC, and that path - // doesn't exist. So when generated libraries won't be able to find it - // during load time. To solve this, we change the dep in the generated - // library to be relative to the rpath which should be set correctly during - // linking. This shouldn't have an impact when /DLC/concrete/.dylibs/* isn't - // a dependecy in the first place (when not using python). + // runtime library will have an id that is prefixed with /DLC, and that + // path doesn't exist. So when generated libraries won't be able to find + // it during load time. To solve this, we change the dep in the generated + // library to be relative to the rpath which should be set correctly + // during linking. This shouldn't have an impact when + // /DLC/concrete/.dylibs/* isn't a dependecy in the first place (when not + // using python). if (fixRuntimeDep) { std::string fixRuntimeDepCmd = "install_name_tool -change " "/DLC/concrete/.dylibs/" + diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 1015ecd45..123d79faf 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -35,7 +35,8 @@ #include #include #include -#include +#include +#include #include #include #include @@ -218,6 +219,23 @@ transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + unsigned int chunkSize, unsigned int chunkWidth) { + mlir::PassManager pm(&context); + addPotentiallyNestedPass( + pm, + mlir::concretelang::createFHEBigIntTransformPass(chunkSize, chunkWidth), + enablePass); + // We want to fully unroll for loops introduced by the BigInt transform since + // MANP doesn't support loops. This is a workaround that make the IR much + // bigger than it should be + addPotentiallyNestedPass(pm, mlir::createLoopUnrollPass(-1, false, true), + enablePass); + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::Optional &fheContext, diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 373d02f88..cffb97950 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -206,6 +206,18 @@ llvm::cl::list llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); +llvm::cl::opt chunkSize( + "chunk-size", + llvm::cl::desc( + "Chunk size while decomposing big integers into chunks, default is 4"), + llvm::cl::init(4)); + +llvm::cl::opt chunkWidth( + "chunk-width", + llvm::cl::desc( + "Chunk width while decomposing big integers into chunks, default is 2"), + llvm::cl::init(2)); + llvm::cl::opt jitKeySetCachePath( "jit-keyset-cache-path", llvm::cl::desc("Path to cache KeySet content (unsecure)")); @@ -338,6 +350,8 @@ cmdlineCompilationOptions() { cmdline::unrollLoopsWithSDFGConvertibleOps; options.optimizeConcrete = cmdline::optimizeConcrete; options.emitGPUOps = cmdline::emitGPUOps; + options.chunkSize = cmdline::chunkSize; + options.chunkWidth = cmdline::chunkWidth; if (!cmdline::v0Constraint.empty()) { if (cmdline::v0Constraint.size() != 2) { diff --git a/compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir b/compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir new file mode 100644 index 000000000..0f0fbc177 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir @@ -0,0 +1,24 @@ +// RUN: concretecompiler --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func.func @add_chunked_eint(%arg0: tensor<32x!FHE.eint<4>>, %arg1: tensor<32x!FHE.eint<4>>) -> tensor<32x!FHE.eint<4>> +func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { + // CHECK-NEXT: %[[V0:.*]] = "FHE.zero"() : () -> !FHE.eint<4> + // CHECK-NEXT: %[[V1:.*]] = "FHE.zero_tensor"() : () -> tensor<32x!FHE.eint<4>> + // CHECK-NEXT: %[[c4_i5:.*]] = arith.constant 4 : i5 + // CHECK-NEXT: %[[V2:.*]] = affine.for %arg2 = 0 to 32 iter_args(%arg3 = %[[V1]]) -> (tensor<32x!FHE.eint<4>>) { + // CHECK-NEXT: %[[V3:.*]] = tensor.extract %arg0[%arg2] : tensor<32x!FHE.eint<4>> + // CHECK-NEXT: %[[V4:.*]] = tensor.extract %arg1[%arg2] : tensor<32x!FHE.eint<4>> + // CHECK-NEXT: %[[V5:.*]] = "FHE.add_eint"(%[[V3]], %[[V4]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4> + // CHECK-NEXT: %[[V6:.*]] = "FHE.add_eint"(%[[V5]], %[[V0]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4> + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]> : tensor<16xi64> + // CHECK-NEXT: %[[V7:.*]] = "FHE.apply_lookup_table"(%[[V6]], %[[cst]]) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4> + // CHECK-NEXT: %[[V8:.*]] = "FHE.mul_eint_int"(%[[V7]], %[[c4_i5]]) : (!FHE.eint<4>, i5) -> !FHE.eint<4> + // CHECK-NEXT: %[[V9:.*]] = "FHE.sub_eint"(%[[V6]], %[[V8]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4> + // CHECK-NEXT: %[[V10:.*]] = tensor.insert %[[V9]] into %arg3[%arg2] : tensor<32x!FHE.eint<4>> + // CHECK-NEXT: affine.yield %[[V10]] : tensor<32x!FHE.eint<4>> + // CHECK-NEXT: } + // CHECK-NEXT: return %2 : tensor<32x!FHE.eint<4>> + + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>) + return %1: !FHE.chunked_eint<64> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/big_int.mlir b/compiler/tests/check_tests/Dialect/FHE/big_int.mlir new file mode 100644 index 000000000..b571b7e64 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/big_int.mlir @@ -0,0 +1,34 @@ +// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> +func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64 + // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64> + // CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64> + + %0 = arith.constant 1 : i64 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>) + return %1: !FHE.chunked_eint<64> +} + +// CHECK-LABEL: func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> +func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64 + // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64> + // CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64> + + %0 = arith.constant 1 : i64 + %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>) + return %1: !FHE.chunked_eint<64> +} + +// CHECK-LABEL: func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> +func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> + // CHECK-NEXT: return %[[V1]] : !FHE.chunked_eint<64> + + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>) + return %1: !FHE.chunked_eint<64> +} + +// TODO: max/min