From 9b5a2e46dad07a996762e95bb057eb4483d2cf7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Fri, 9 Feb 2024 11:39:06 +0100 Subject: [PATCH] feat(compiler): support multi-circuit compilation --- .../include/concretelang/Common/Compat.h | 31 +- .../include/concretelang/Common/Protocol.h | 1 + .../Dialect/Concrete/Analysis/MemoryUsage.h | 2 +- .../Dialect/TFHE/Analysis/ExtractStatistics.h | 2 +- .../Support/CompilationFeedback.h | 59 +- .../concretelang/Support/CompilerEngine.h | 30 +- .../include/concretelang/Support/Encodings.h | 8 +- .../include/concretelang/Support/Pipeline.h | 4 +- .../Support/ProgramInfoGeneration.h | 4 +- .../concretelang/Support/V0Parameters.h | 4 +- .../TestLib/{TestCircuit.h => TestProgram.h} | 54 +- .../lib/Bindings/Python/CompilerAPIModule.cpp | 161 +++--- .../Python/concrete/compiler/__init__.py | 2 +- .../concrete/compiler/client_support.py | 19 +- .../concrete/compiler/compilation_feedback.py | 104 +++- .../concrete/compiler/compilation_options.py | 14 +- .../compiler/library_compilation_result.py | 12 +- .../concrete/compiler/library_support.py | 39 +- .../compiler/simulated_value_decrypter.py | 4 +- .../compiler/simulated_value_exporter.py | 6 +- .../concrete/compiler/value_decrypter.py | 6 +- .../concrete/compiler/value_exporter.py | 6 +- .../Dialect/Concrete/Analysis/MemoryUsage.cpp | 61 ++- .../TFHE/Analysis/ExtractStatistics.cpp | 130 +++-- .../lib/Support/CompilationFeedback.cpp | 513 ++++++++++-------- .../compiler/lib/Support/CompilerEngine.cpp | 66 +-- .../compiler/lib/Support/Encodings.cpp | 58 +- .../compiler/lib/Support/Pipeline.cpp | 5 +- .../lib/Support/ProgramInfoGeneration.cpp | 78 ++- .../compiler/lib/Support/V0Parameters.cpp | 10 +- .../concrete-compiler/compiler/src/main.cpp | 32 +- .../check_tests/BugReport/bug_report_785.mlir | 18 +- .../check_tests/BugReport/bug_report_858.mlir | 2 +- .../check_tests/BugReport/bug_report_890.mlir | 12 - .../Conversion/ConcreteToLLVM/gpu_ops.mlir | 2 +- .../TFHEGlobalParametrization/pbs_ks_bs.mlir | 2 +- .../Conversion/TFHEToConcrete/add_glwe.mlir | 2 +- .../TFHEToConcrete/add_glwe_int.mlir | 2 +- .../Conversion/TFHEToConcrete/bootstrap.mlir | 2 +- .../encode_expand_lut_for_bootstrap.mlir | 2 +- .../TFHEToConcrete/encode_lut_for_woppbs.mlir | 2 +- .../encode_plaintext_with_crt.mlir | 2 +- .../Conversion/TFHEToConcrete/keyswitch.mlir | 2 +- .../TFHEToConcrete/mul_glwe_int.mlir | 2 +- .../Conversion/TFHEToConcrete/neg_glwe.mlir | 2 +- .../TFHEToConcrete/sub_int_glwe.mlir | 2 +- .../tensor_exapand_collapse_shape.mlir | 2 +- .../TFHEToConcrete/tensor_from_elements.mlir | 2 +- .../TFHEToConcrete/tensor_identity.mlir | 2 +- .../check_tests/Dialect/FHE/folding.mlir | 2 +- .../Dialect/FHELinalg/folding.mlir | 2 +- .../Dialect/TFHE/no_optimization.mlir | 2 +- .../Dialect/TFHE/optimization.mlir | 2 +- .../check_tests/Transforms/batching.mlir | 2 +- .../check_tests/TypeInference/inference.mlir | 2 +- .../end_to_end_benchmark.cpp | 11 +- .../end_to_end_mlbench.cpp | 2 +- .../end_to_end_jit_aes_short.cc | 2 +- .../end_to_end_jit_auto_parallelization.cc | 2 +- .../end_to_end_jit_chunked_int.cc | 2 +- .../end_to_end_jit_distributed.cc | 2 +- .../end_to_end_tests/end_to_end_jit_lambda.cc | 2 +- .../end_to_end_tests/end_to_end_jit_test.cc | 54 +- .../end_to_end_tests/end_to_end_jit_test.h | 17 +- .../tests/end_to_end_tests/end_to_end_test.cc | 8 +- .../tests/end_to_end_tests/end_to_end_test.h | 3 +- .../compiler/tests/python/test_compilation.py | 120 ++-- .../compiler/tests/python/test_simulation.py | 2 +- .../compiler/tests/python/test_statistics.py | 5 +- .../Encodings/Encodings_unit_tests.cpp | 22 +- .../concretelang/SDFG/SDFG_unit_tests.cpp | 28 +- .../TestLib/testlib_unit_test.cpp | 34 +- .../src/concrete-optimizer.rs | 6 + .../src/cpp/concrete-optimizer.cpp | 7 + .../src/cpp/concrete-optimizer.hpp | 1 + .../src/dag/unparametrized.rs | 51 ++ .../concrete/fhe/compilation/server.py | 78 +-- .../src/concrete-protocol.capnp | 7 + 78 files changed, 1200 insertions(+), 865 deletions(-) rename compilers/concrete-compiler/compiler/include/concretelang/TestLib/{TestCircuit.h => TestProgram.h} (83%) delete mode 100644 compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_890.mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h index b09b160d9..ae2664489 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h @@ -286,7 +286,6 @@ struct LibraryCompilationResult { /// The output directory path where the compilation artifacts have been /// generated. std::string outputDirPath; - std::string funcName; }; class LibrarySupport { @@ -318,14 +317,8 @@ public: return std::move(err); } - if (!options.mainFuncName.has_value()) { - return StreamStringError("Need to have a funcname to compile library"); - } - this->funcName = options.mainFuncName.value(); - auto result = std::make_unique(); result->outputDirPath = outputPath; - result->funcName = *options.mainFuncName; return std::move(result); } @@ -356,20 +349,15 @@ public: return std::move(err); } - if (!options.mainFuncName.has_value()) { - return StreamStringError("Need to have a funcname to compile library"); - } - this->funcName = options.mainFuncName.value(); - auto result = std::make_unique(); result->outputDirPath = outputPath; - result->funcName = *options.mainFuncName; return std::move(result); } /// Load the server lambda from the compilation result. llvm::Expected<::concretelang::serverlib::ServerLambda> - loadServerLambda(LibraryCompilationResult &result, bool useSimulation) { + loadServerLambda(LibraryCompilationResult &result, std::string circuitName, + bool useSimulation) { EXPECTED_TRY(auto programInfo, getProgramInfo()); EXPECTED_TRY(ServerProgram serverProgram, outcomeToExpected(ServerProgram::load(programInfo.asReader(), @@ -377,7 +365,7 @@ public: useSimulation))); EXPECTED_TRY( ServerCircuit serverCircuit, - outcomeToExpected(serverProgram.getServerCircuit(result.funcName))); + outcomeToExpected(serverProgram.getServerCircuit(circuitName))); return ::concretelang::serverlib::ServerLambda{serverCircuit, useSimulation}; } @@ -386,13 +374,6 @@ public: llvm::Expected<::concretelang::clientlib::ClientParameters> loadClientParameters(LibraryCompilationResult &result) { EXPECTED_TRY(auto programInfo, getProgramInfo()); - if (programInfo.asReader().getCircuits().size() > 1) { - return StreamStringError("ClientLambda: Provided program info contains " - "more than one circuit."); - } - if (programInfo.asReader().getCircuits()[0].getName() != result.funcName) { - return StreamStringError("Unexpected circuit name in program info"); - } auto secretKeys = std::vector<::concretelang::clientlib::LweSecretKeyParam>(); for (auto key : programInfo.asReader().getKeyset().getLweSecretKeys()) { @@ -441,15 +422,14 @@ public: loadCompilationResult() { auto result = std::make_unique(); result->outputDirPath = outputPath; - result->funcName = funcName; return std::move(result); } - llvm::Expected + llvm::Expected loadCompilationFeedback(LibraryCompilationResult &result) { auto path = CompilerEngine::Library::getCompilationFeedbackPath( result.outputDirPath); - auto feedback = CompilationFeedback::load(path); + auto feedback = ProgramCompilationFeedback::load(path); if (feedback.has_error()) { return StreamStringError(feedback.error().mesg); } @@ -499,7 +479,6 @@ public: private: std::string outputPath; - std::string funcName; std::string runtimeLibraryPath; /// Flags to select generated artifacts bool generateSharedLib; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h index 3b4c8273f..092d43735 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Protocol.h @@ -240,6 +240,7 @@ private: template struct Message; template struct Message; +template struct Message; template struct Message; template struct Message; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h index 38d1f5285..df4d6ae65 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h @@ -15,7 +15,7 @@ namespace mlir { namespace concretelang { std::unique_ptr> -createMemoryUsagePass(CompilationFeedback &feedback); +createMemoryUsagePass(ProgramCompilationFeedback &feedback); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h index 866ff6d04..bc945d416 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h @@ -15,7 +15,7 @@ namespace mlir { namespace concretelang { std::unique_ptr> -createStatisticExtractionPass(CompilationFeedback &feedback); +createStatisticExtractionPass(ProgramCompilationFeedback &feedback); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h index 23f2bc121..91deb226a 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h @@ -44,27 +44,13 @@ enum class KeyType { struct Statistic { std::string location; PrimitiveOperation operation; - std::vector> keys; - size_t count; + std::vector> keys; + int64_t count; }; -struct CompilationFeedback { - double complexity; - - /// @brief Probability of error for every PBS. - double pError; - - /// @brief Probability of error for the whole programs. - double globalPError; - - /// @brief the total number of bytes of secret keys - uint64_t totalSecretKeysSize; - - /// @brief the total number of bytes of bootstrap keys - uint64_t totalBootstrapKeysSize; - - /// @brief the total number of bytes of keyswitch keys - uint64_t totalKeyswitchKeysSize; +struct CircuitCompilationFeedback { + /// @brief the name of circuit. + std::string name; /// @brief the total number of bytes of inputs uint64_t totalInputsSize; @@ -81,25 +67,52 @@ struct CompilationFeedback { /// @brief memory usage per location std::map memoryUsagePerLoc; + /// Fill the sizes from the program info. + void fillFromCircuitInfo(concreteprotocol::CircuitInfo::Reader params); +}; + +struct ProgramCompilationFeedback { + double complexity; + + /// @brief Probability of error for every PBS. + double pError; + + /// @brief Probability of error for the whole programs. + double globalPError; + + /// @brief the total number of bytes of secret keys + uint64_t totalSecretKeysSize; + + /// @brief the total number of bytes of bootstrap keys + uint64_t totalBootstrapKeysSize; + + /// @brief the total number of bytes of keyswitch keys + uint64_t totalKeyswitchKeysSize; + + /// @brief the feedback for each circuit + std::vector circuitFeedbacks; + /// Fill the sizes from the program info. void fillFromProgramInfo(const Message ¶ms); /// Load the compilation feedback from a path - static outcome::checked + static outcome::checked load(std::string path); }; -llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &); +llvm::json::Value +toJSON(const mlir::concretelang::ProgramCompilationFeedback &); bool fromJSON(const llvm::json::Value, - mlir::concretelang::CompilationFeedback &, llvm::json::Path); + mlir::concretelang::ProgramCompilationFeedback &, + llvm::json::Path); } // namespace concretelang } // namespace mlir static inline llvm::raw_ostream & operator<<(llvm::raw_string_ostream &OS, - mlir::concretelang::CompilationFeedback cp) { + mlir::concretelang::ProgramCompilationFeedback cp) { return OS << llvm::formatv("{0:2}", toJSON(cp)); } #endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h index c505d1687..7422ea094 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h @@ -79,8 +79,6 @@ struct CompilationOptions { std::optional> fhelinalgTileSizes; - std::optional mainFuncName; - optimizer::Config optimizerConfig; /// When decomposing big integers into chunks, chunkSize is the total number @@ -92,7 +90,9 @@ struct CompilationOptions { /// When compiling from a dialect lower than FHE, one needs to provide /// encodings info manually to allow the client lib to be generated. - std::optional> encodings; + std::optional> encodings; + + bool skipProgramInfo; bool compressEvaluationKeys; @@ -102,20 +102,14 @@ struct CompilationOptions { maxBatchSize(std::numeric_limits::max()), emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false), dataflowParallelize(false), optimizeTFHE(true), simulate(false), emitGPUOps(false), - mainFuncName(std::nullopt), optimizerConfig(optimizer::DEFAULT_CONFIG), - chunkIntegers(false), chunkSize(4), chunkWidth(2), - encodings(std::nullopt), compressEvaluationKeys(false){}; - - CompilationOptions(std::string funcname) : CompilationOptions() { - mainFuncName = funcname; - } + optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false), + chunkSize(4), chunkWidth(2), encodings(std::nullopt), + skipProgramInfo(false), compressEvaluationKeys(false){}; /// @brief Constructor for CompilationOptions with default parameters for a /// specific backend. - /// @param funcname The name of the function to compile. /// @param backend The backend to target. - CompilationOptions(std::string funcname, enum Backend backend) - : CompilationOptions(funcname) { + CompilationOptions(enum Backend backend) : CompilationOptions() { switch (backend) { case Backend::CPU: loopParallelize = true; @@ -143,7 +137,7 @@ public: std::optional> mlirModuleRef; std::optional> programInfo; - std::optional feedback; + std::optional feedback; std::unique_ptr llvmModule; std::optional fheContext; @@ -157,7 +151,7 @@ public: /// Path to the runtime library. Will be linked to the output library if set std::string runtimeLibraryPath; bool cleanUp; - mlir::concretelang::CompilationFeedback compilationFeedback; + mlir::concretelang::ProgramCompilationFeedback compilationFeedback; Message programInfo; public: @@ -280,7 +274,7 @@ public: CompilerEngine(std::shared_ptr compilationContext) : overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(), - generateProgramInfo(compilerOptions.mainFuncName.has_value()), + generateProgramInfo(true), enablePass([](mlir::Pass *pass) { return true; }), compilationContext(compilationContext) {} @@ -325,10 +319,6 @@ public: if (compilerOptions.v0FHEConstraints.has_value()) { setFHEConstraints(*compilerOptions.v0FHEConstraints); } - - if (compilerOptions.mainFuncName.has_value()) { - setGenerateProgramInfo(true); - } } CompilationOptions &getCompilationOptions() { return compilerOptions; } diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Encodings.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Encodings.h index 3a66be4f6..75348ed6f 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Encodings.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Encodings.h @@ -35,11 +35,11 @@ namespace mlir { namespace concretelang { namespace encodings { -llvm::Expected> -getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module); +llvm::Expected> +getProgramEncoding(mlir::ModuleOp module); -void setCircuitEncodingModes( - Message &info, +void setProgramEncodingModes( + Message &info, std::optional< Message> maybeChunk, diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index d79e7e3a6..b5e72d8b9 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -77,7 +77,7 @@ normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - CompilationFeedback &feedback); + ProgramCompilationFeedback &feedback); mlir::LogicalResult lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, @@ -86,7 +86,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - CompilationFeedback &feedback); + ProgramCompilationFeedback &feedback); mlir::LogicalResult lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/ProgramInfoGeneration.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/ProgramInfoGeneration.h index 08aad6bc1..108f155bf 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/ProgramInfoGeneration.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/ProgramInfoGeneration.h @@ -21,8 +21,8 @@ namespace concretelang { llvm::Expected> createProgramInfoFromTfheDialect( - mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity, - Message &encodings, + mlir::ModuleOp module, int bitsOfSecurity, + const Message &encodings, bool compressEvaluationKeys); } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h index d2ceb7ca5..8d2a687a6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h @@ -151,10 +151,10 @@ typedef std::variant Solution; } // namespace optimizer -struct CompilationFeedback; +struct ProgramCompilationFeedback; llvm::Expected -getSolution(optimizer::Description &descr, CompilationFeedback &feedback, +getSolution(optimizer::Description &descr, ProgramCompilationFeedback &feedback, optimizer::Config optimizerConfig); // As for now the solution which contains a crt encoding is mono parameter only diff --git a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestCircuit.h b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h similarity index 83% rename from compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestCircuit.h rename to compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h index 077192ac9..88ebc3d41 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestCircuit.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h @@ -36,26 +36,26 @@ using concretelang::values::Value; namespace concretelang { namespace testlib { -class TestCircuit { +class TestProgram { public: - TestCircuit(mlir::concretelang::CompilationOptions options) + TestProgram(mlir::concretelang::CompilationOptions options) : artifactDirectory(createTempFolderIn(getSystemTempFolderPath())), compiler(mlir::concretelang::CompilationContext::createShared()), encryptionCsprng(std::make_shared(0)) { compiler.setCompilationOptions(options); } - TestCircuit(TestCircuit &&tc) + TestProgram(TestProgram &&tc) : artifactDirectory(tc.artifactDirectory), compiler(tc.compiler), library(tc.library), keyset(tc.keyset), encryptionCsprng(tc.encryptionCsprng) { tc.artifactDirectory = ""; }; - TestCircuit(TestCircuit &tc) = delete; + TestProgram(TestProgram &tc) = delete; - ~TestCircuit() { + ~TestProgram() { auto d = getArtifactDirectory(); if (d.empty()) return; @@ -93,16 +93,17 @@ public: return outcome::success(); } - Result> call(std::vector inputs) { + Result> call(std::vector inputs, + std::string name = "main") { // preprocess arguments auto preparedArgs = std::vector(); - OUTCOME_TRY(auto clientCircuit, getClientCircuit()); + OUTCOME_TRY(auto clientCircuit, getClientCircuit(name)); for (size_t i = 0; i < inputs.size(); i++) { OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i)); preparedArgs.push_back(preparedInput); } // Call server - OUTCOME_TRY(auto returns, callServer(preparedArgs)); + OUTCOME_TRY(auto returns, callServer(preparedArgs, name)); // postprocess arguments std::vector processedOutputs(returns.size()); for (size_t i = 0; i < processedOutputs.size(); i++) { @@ -113,17 +114,18 @@ public: } Result> compose_n_times(std::vector inputs, - size_t n) { + size_t n, + std::string name = "main") { // preprocess arguments auto preparedArgs = std::vector(); - OUTCOME_TRY(auto clientCircuit, getClientCircuit()); + OUTCOME_TRY(auto clientCircuit, getClientCircuit(name)); for (size_t i = 0; i < inputs.size(); i++) { OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i)); preparedArgs.push_back(preparedInput); } // Call server multiple times in a row for (size_t i = 0; i < n; i++) { - OUTCOME_TRY(preparedArgs, callServer(preparedArgs)); + OUTCOME_TRY(preparedArgs, callServer(preparedArgs, name)); } // postprocess arguments std::vector processedOutputs(preparedArgs.size()); @@ -135,9 +137,9 @@ public: } Result> - callServer(std::vector inputs) { + callServer(std::vector inputs, std::string name = "main") { std::vector returns; - OUTCOME_TRY(auto serverCircuit, getServerCircuit()); + OUTCOME_TRY(auto serverCircuit, getServerCircuit(name)); if (compiler.getCompilationOptions().simulate) { OUTCOME_TRY(returns, serverCircuit.simulate(inputs)); } else { @@ -146,29 +148,25 @@ public: return returns; } - Result getClientCircuit() { + Result getClientCircuit(std::string name = "main") { OUTCOME_TRY(auto lib, getLibrary()); OUTCOME_TRY(auto ks, getKeyset()); auto programInfo = lib.getProgramInfo(); OUTCOME_TRY(auto clientProgram, ClientProgram::create(programInfo, ks.client, encryptionCsprng, isSimulation())); - OUTCOME_TRY(auto clientCircuit, - clientProgram.getClientCircuit( - programInfo.asReader().getCircuits()[0].getName())); + OUTCOME_TRY(auto clientCircuit, clientProgram.getClientCircuit(name)); return clientCircuit; } - Result getServerCircuit() { + Result getServerCircuit(std::string name = "main") { OUTCOME_TRY(auto lib, getLibrary()); auto programInfo = lib.getProgramInfo(); OUTCOME_TRY(auto serverProgram, ServerProgram::load(programInfo, lib.getSharedLibraryPath(artifactDirectory), isSimulation())); - OUTCOME_TRY(auto serverCircuit, - serverProgram.getServerCircuit( - programInfo.asReader().getCircuits()[0].getName())); + OUTCOME_TRY(auto serverCircuit, serverProgram.getServerCircuit(name)); return serverCircuit; } @@ -177,14 +175,14 @@ private: Result getLibrary() { if (!library.has_value()) { - return StringError("TestCircuit: compilation has not been done\n"); + return StringError("TestProgram: compilation has not been done\n"); } return *library; } Result getKeyset() { if (!keyset.has_value()) { - return StringError("TestCircuit: keyset has not been generated\n"); + return StringError("TestProgram: keyset has not been generated\n"); } return *keyset; } @@ -206,10 +204,10 @@ private: void deleteFolder(const std::string &folder) { auto ec = std::error_code(); - llvm::errs() << "TestCircuit: delete artifact directory(" << folder + llvm::errs() << "TestProgram: delete artifact directory(" << folder << ")\n"; if (!std::filesystem::remove_all(folder, ec)) { - llvm::errs() << "TestCircuit: fail to delete directory(" << folder + llvm::errs() << "TestProgram: fail to delete directory(" << folder << "), error(" << ec.message() << ")\n"; } } @@ -232,10 +230,10 @@ private: for (size_t i = 0; i < 5; i++) { auto pathString = new_path(); auto ec = std::error_code(); - llvm::errs() << "TestCircuit: create temporary directory(" << pathString + llvm::errs() << "TestProgram: create temporary directory(" << pathString << ")\n"; if (!std::filesystem::create_directory(pathString, ec)) { - llvm::errs() << "TestCircuit: fail to create temporary directory(" + llvm::errs() << "TestProgram: fail to create temporary directory(" << pathString << "), "; if (ec) { llvm::errs() << "already exists"; @@ -243,7 +241,7 @@ private: llvm::errs() << "error(" << ec.message() << ")"; } } else { - llvm::errs() << "TestCircuit: directory(" << pathString + llvm::errs() << "TestProgram: directory(" << pathString << ") successfuly created\n"; return pathString; } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 14b8f8242..d850dadfa 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -102,7 +102,8 @@ concretelang::clientlib::ClientParameters library_load_client_parameters( return *clientParameters; } -mlir::concretelang::CompilationFeedback library_load_compilation_feedback( +mlir::concretelang::ProgramCompilationFeedback +library_load_compilation_feedback( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, @@ -113,9 +114,10 @@ mlir::concretelang::CompilationFeedback library_load_compilation_feedback( concretelang::serverlib::ServerLambda library_load_server_lambda(LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result, - bool useSimulation) { + std::string circuitName, bool useSimulation) { GET_OR_THROW_LLVM_EXPECTED( - serverLambda, support.support.loadServerLambda(result, useSimulation)); + serverLambda, + support.support.loadServerLambda(result, circuitName, useSimulation)); return *serverLambda; } @@ -176,7 +178,8 @@ key_set(concretelang::clientlib::ClientParameters clientParameters, std::unique_ptr encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, concretelang::clientlib::KeySet &keySet, - llvm::ArrayRef args) { + llvm::ArrayRef args, + const std::string &circuitName) { auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( clientParameters.programInfo.asReader(), keySet.keyset.client, std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( @@ -185,15 +188,10 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, if (maybeProgram.has_failure()) { throw std::runtime_error(maybeProgram.as_failure().error().mesg); } - auto circuit = maybeProgram.value() - .getClientCircuit(clientParameters.programInfo.asReader() - .getCircuits()[0] - .getName()) - .value(); + auto circuit = maybeProgram.value().getClientCircuit(circuitName).value(); std::vector output; for (size_t i = 0; i < args.size(); i++) { - auto info = - clientParameters.programInfo.asReader().getCircuits()[0].getInputs()[i]; + auto info = circuit.getCircuitInfo().asReader().getInputs()[i]; auto typeTransformer = getPythonTypeTransformer(info); auto input = typeTransformer(args[i]->value); auto maybePrepared = circuit.prepareInput(input, i); @@ -211,7 +209,8 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, std::vector decrypt_result(concretelang::clientlib::ClientParameters clientParameters, concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::PublicResult &publicResult) { + concretelang::clientlib::PublicResult &publicResult, + const std::string &circuitName) { auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( clientParameters.programInfo.asReader(), keySet.keyset.client, std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( @@ -220,11 +219,7 @@ decrypt_result(concretelang::clientlib::ClientParameters clientParameters, if (maybeProgram.has_failure()) { throw std::runtime_error(maybeProgram.as_failure().error().mesg); } - auto circuit = maybeProgram.value() - .getClientCircuit(clientParameters.programInfo.asReader() - .getCircuits()[0] - .getName()) - .value(); + auto circuit = maybeProgram.value().getClientCircuit(circuitName).value(); std::vector results; for (auto e : llvm::enumerate(publicResult.values)) { auto maybeProcessed = circuit.processOutput(e.value(), e.index()); @@ -370,9 +365,10 @@ valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) { return maybeString.value(); } -concretelang::clientlib::ValueExporter createValueExporter( - concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::ClientParameters &clientParameters) { +concretelang::clientlib::ValueExporter +createValueExporter(concretelang::clientlib::KeySet &keySet, + concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( clientParameters.programInfo.asReader(), keySet.keyset.client, std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( @@ -381,13 +377,13 @@ concretelang::clientlib::ValueExporter createValueExporter( if (maybeProgram.has_failure()) { throw std::runtime_error(maybeProgram.as_failure().error().mesg); } - auto maybeCircuit = maybeProgram.value().getClientCircuit( - clientParameters.programInfo.asReader().getCircuits()[0].getName()); + auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); return ::concretelang::clientlib::ValueExporter{maybeCircuit.value()}; } concretelang::clientlib::SimulatedValueExporter createSimulatedValueExporter( - concretelang::clientlib::ClientParameters &clientParameters) { + concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( clientParameters.programInfo, ::concretelang::keysets::ClientKeyset(), @@ -397,15 +393,15 @@ concretelang::clientlib::SimulatedValueExporter createSimulatedValueExporter( if (maybeProgram.has_failure()) { throw std::runtime_error(maybeProgram.as_failure().error().mesg); } - auto maybeCircuit = maybeProgram.value().getClientCircuit( - clientParameters.programInfo.asReader().getCircuits()[0].getName()); + auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); return ::concretelang::clientlib::SimulatedValueExporter{ maybeCircuit.value()}; } concretelang::clientlib::ValueDecrypter createValueDecrypter( concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::ClientParameters &clientParameters) { + concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( clientParameters.programInfo.asReader(), keySet.keyset.client, @@ -415,13 +411,13 @@ concretelang::clientlib::ValueDecrypter createValueDecrypter( if (maybeProgram.has_failure()) { throw std::runtime_error(maybeProgram.as_failure().error().mesg); } - auto maybeCircuit = maybeProgram.value().getClientCircuit( - clientParameters.programInfo.asReader().getCircuits()[0].getName()); + auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); return ::concretelang::clientlib::ValueDecrypter{maybeCircuit.value()}; } concretelang::clientlib::SimulatedValueDecrypter createSimulatedValueDecrypter( - concretelang::clientlib::ClientParameters &clientParameters) { + concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( clientParameters.programInfo.asReader(), @@ -432,8 +428,7 @@ concretelang::clientlib::SimulatedValueDecrypter createSimulatedValueDecrypter( if (maybeProgram.has_failure()) { throw std::runtime_error(maybeProgram.as_failure().error().mesg); } - auto maybeCircuit = maybeProgram.value().getClientCircuit( - clientParameters.programInfo.asReader().getCircuits()[0].getName()); + auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); return ::concretelang::clientlib::SimulatedValueDecrypter{ maybeCircuit.value()}; } @@ -699,14 +694,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .export_values(); pybind11::class_(m, "CompilationOptions") - .def(pybind11::init( - [](std::string funcname, mlir::concretelang::Backend backend) { - return CompilationOptions(funcname, backend); - })) - .def("set_funcname", - [](CompilationOptions &options, std::string funcname) { - options.mainFuncName = funcname; - }) + .def(pybind11::init([](mlir::concretelang::Backend backend) { + return CompilationOptions(backend); + })) .def("set_verify_diagnostics", [](CompilationOptions &options, bool b) { options.verifyDiagnostics = b; @@ -828,34 +818,46 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def_readonly("keys", &mlir::concretelang::Statistic::keys) .def_readonly("count", &mlir::concretelang::Statistic::count); - pybind11::class_( - m, "CompilationFeedback") + pybind11::class_( + m, "ProgramCompilationFeedback") .def_readonly("complexity", - &mlir::concretelang::CompilationFeedback::complexity) - .def_readonly("p_error", &mlir::concretelang::CompilationFeedback::pError) - .def_readonly("global_p_error", - &mlir::concretelang::CompilationFeedback::globalPError) + &mlir::concretelang::ProgramCompilationFeedback::complexity) + .def_readonly("p_error", + &mlir::concretelang::ProgramCompilationFeedback::pError) + .def_readonly( + "global_p_error", + &mlir::concretelang::ProgramCompilationFeedback::globalPError) .def_readonly( "total_secret_keys_size", - &mlir::concretelang::CompilationFeedback::totalSecretKeysSize) + &mlir::concretelang::ProgramCompilationFeedback::totalSecretKeysSize) + .def_readonly("total_bootstrap_keys_size", + &mlir::concretelang::ProgramCompilationFeedback:: + totalBootstrapKeysSize) + .def_readonly("total_keyswitch_keys_size", + &mlir::concretelang::ProgramCompilationFeedback:: + totalKeyswitchKeysSize) .def_readonly( - "total_bootstrap_keys_size", - &mlir::concretelang::CompilationFeedback::totalBootstrapKeysSize) + "circuit_feedbacks", + &mlir::concretelang::ProgramCompilationFeedback::circuitFeedbacks); + + pybind11::class_( + m, "CircuitCompilationFeedback") + .def_readonly("name", + &mlir::concretelang::CircuitCompilationFeedback::name) .def_readonly( - "total_keyswitch_keys_size", - &mlir::concretelang::CompilationFeedback::totalKeyswitchKeysSize) - .def_readonly("total_inputs_size", - &mlir::concretelang::CompilationFeedback::totalInputsSize) - .def_readonly("total_output_size", - &mlir::concretelang::CompilationFeedback::totalOutputsSize) + "total_inputs_size", + &mlir::concretelang::CircuitCompilationFeedback::totalInputsSize) .def_readonly( - "crt_decompositions_of_outputs", - &mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs) + "total_output_size", + &mlir::concretelang::CircuitCompilationFeedback::totalOutputsSize) + .def_readonly("crt_decompositions_of_outputs", + &mlir::concretelang::CircuitCompilationFeedback:: + crtDecompositionsOfOutputs) .def_readonly("statistics", - &mlir::concretelang::CompilationFeedback::statistics) + &mlir::concretelang::CircuitCompilationFeedback::statistics) .def_readonly( "memory_usage_per_location", - &mlir::concretelang::CompilationFeedback::memoryUsagePerLoc); + &mlir::concretelang::CircuitCompilationFeedback::memoryUsagePerLoc); pybind11::class_>( @@ -872,11 +874,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::class_( m, "LibraryCompilationResult") - .def(pybind11::init([](std::string outputDirPath, std::string funcname) { - return mlir::concretelang::LibraryCompilationResult{ - outputDirPath, - funcname, - }; + .def(pybind11::init([](std::string outputDirPath) { + return mlir::concretelang::LibraryCompilationResult{outputDirPath}; })); pybind11::class_<::concretelang::serverlib::ServerLambda>(m, "LibraryLambda"); pybind11::class_(m, "LibrarySupport") @@ -920,8 +919,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( "load_server_lambda", [](LibrarySupport_Py &support, mlir::concretelang::LibraryCompilationResult &result, - bool useSimulation) { - return library_load_server_lambda(support, result, useSimulation); + std::string circuitName, bool useSimulation) { + return library_load_server_lambda(support, result, circuitName, + useSimulation); }, pybind11::return_value_policy::reference) .def("server_call", @@ -974,19 +974,22 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( "encrypt_arguments", [](::concretelang::clientlib::ClientParameters clientParameters, ::concretelang::clientlib::KeySet &keySet, - std::vector args) { + std::vector args, const std::string &circuitName) { std::vector argsRef; for (auto i = 0u; i < args.size(); i++) { argsRef.push_back(args[i].ptr.get()); } - return encrypt_arguments(clientParameters, keySet, argsRef); + return encrypt_arguments(clientParameters, keySet, argsRef, + circuitName); }) .def_static( "decrypt_result", [](::concretelang::clientlib::ClientParameters clientParameters, ::concretelang::clientlib::KeySet &keySet, - ::concretelang::clientlib::PublicResult &publicResult) { - return decrypt_result(clientParameters, keySet, publicResult); + ::concretelang::clientlib::PublicResult &publicResult, + const std::string &circuitName) { + return decrypt_result(clientParameters, keySet, publicResult, + circuitName); }); pybind11::class_<::concretelang::clientlib::KeySetCache>(m, "KeySetCache") .def(pybind11::init()); @@ -1187,8 +1190,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def_static( "create", [](::concretelang::clientlib::KeySet &keySet, - ::concretelang::clientlib::ClientParameters &clientParameters) { - return createValueExporter(keySet, clientParameters); + ::concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { + return createValueExporter(keySet, clientParameters, circuitName); }) .def("export_scalar", [](::concretelang::clientlib::ValueExporter &exporter, @@ -1233,8 +1237,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( m, "SimulatedValueExporter") .def_static( "create", - [](::concretelang::clientlib::ClientParameters &clientParameters) { - return createSimulatedValueExporter(clientParameters); + [](::concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { + return createSimulatedValueExporter(clientParameters, circuitName); }) .def("export_scalar", [](::concretelang::clientlib::SimulatedValueExporter &exporter, @@ -1279,8 +1284,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def_static( "create", [](::concretelang::clientlib::KeySet &keySet, - ::concretelang::clientlib::ClientParameters &clientParameters) { - return createValueDecrypter(keySet, clientParameters); + ::concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { + return createValueDecrypter(keySet, clientParameters, circuitName); }) .def("decrypt", [](::concretelang::clientlib::ValueDecrypter &decrypter, @@ -1303,8 +1309,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( m, "SimulatedValueDecrypter") .def_static( "create", - [](::concretelang::clientlib::ClientParameters &clientParameters) { - return createSimulatedValueDecrypter(clientParameters); + [](::concretelang::clientlib::ClientParameters &clientParameters, + const std::string &circuitName) { + return createSimulatedValueDecrypter(clientParameters, circuitName); }) .def("decrypt", [](::concretelang::clientlib::SimulatedValueDecrypter &decrypter, diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index 14cdc504b..7a712cbd3 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -21,7 +21,7 @@ from .compilation_options import CompilationOptions, Encoding from .compilation_context import CompilationContext from .key_set_cache import KeySetCache from .client_parameters import ClientParameters -from .compilation_feedback import CompilationFeedback +from .compilation_feedback import ProgramCompilationFeedback, CircuitCompilationFeedback from .key_set import KeySet from .public_result import PublicResult from .public_arguments import PublicArguments diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py index bc9d91fed..01c051e3d 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py @@ -116,6 +116,7 @@ class ClientSupport(WrapperCpp): client_parameters: ClientParameters, keyset: KeySet, args: List[Union[int, np.ndarray]], + circuit_name: str = "main", ) -> PublicArguments: """Prepare arguments for encrypted computation. @@ -126,10 +127,12 @@ class ClientSupport(WrapperCpp): client_parameters (ClientParameters): client parameters specification keyset (KeySet): keyset used to encrypt arguments that require encryption args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments + circuit_name(str): the name of the circuit for which to encrypt Raises: TypeError: if client_parameters is not of type ClientParameters TypeError: if keyset is not of type KeySet + TypeError: if circuit_name is not of type str Returns: PublicArguments: public arguments for execution @@ -140,6 +143,10 @@ class ClientSupport(WrapperCpp): ) if not isinstance(keyset, KeySet): raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}") + if not isinstance(circuit_name, str): + raise TypeError( + f"circuit_name must be of type str, not {type(circuit_name)}" + ) signs = client_parameters.input_signs() if len(signs) != len(args): @@ -156,6 +163,7 @@ class ClientSupport(WrapperCpp): client_parameters.cpp(), keyset.cpp(), [arg.cpp() for arg in lambda_arguments], + circuit_name, ) ) @@ -164,17 +172,20 @@ class ClientSupport(WrapperCpp): client_parameters: ClientParameters, keyset: KeySet, public_result: PublicResult, + circuit_name: str = "main", ) -> Union[int, np.ndarray]: """Decrypt a public result using the keyset. Args: client_parameters (ClientParameters): client parameters for decryption keyset (KeySet): keyset used for decryption - public_result: public result to decrypt + public_result (PublicResult): public result to decrypt + circuit_name (str): name of the circuit for which to decrypt Raises: TypeError: if keyset is not of type KeySet TypeError: if public_result is not of type PublicResult + TypeError: if circuit_name is not of type str RuntimeError: if the result is of an unknown type Returns: @@ -186,8 +197,12 @@ class ClientSupport(WrapperCpp): raise TypeError( f"public_result must be of type PublicResult, not {type(public_result)}" ) + if not isinstance(circuit_name, str): + raise TypeError( + f"circuit_name must be of type str, not {type(circuit_name)}" + ) results = _ClientSupport.decrypt_result( - client_parameters.cpp(), keyset.cpp(), public_result.cpp() + client_parameters.cpp(), keyset.cpp(), public_result.cpp(), circuit_name ) def process_result(result): diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index 499c9a79f..9fc2448c1 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -8,7 +8,8 @@ from typing import Dict, Set # pylint: disable=no-name-in-module,import-error,too-many-instance-attributes from mlir._mlir_libs._concretelang._compiler import ( - CompilationFeedback as _CompilationFeedback, + ProgramCompilationFeedback as _ProgramCompilationFeedback, + CircuitCompilationFeedback as _CircuitCompilationFeedback, KeyType, PrimitiveOperation, ) @@ -39,38 +40,36 @@ def tag_from_location(location): return tag -class CompilationFeedback(WrapperCpp): - """CompilationFeedback is a set of hint computed by the compiler engine.""" +class CircuitCompilationFeedback(WrapperCpp): + """CircuitCompilationFeedback is a set of hint computed by the compiler engine for a circuit.""" - def __init__(self, compilation_feedback: _CompilationFeedback): + def __init__(self, circuit_compilation_feedback: _CircuitCompilationFeedback): """Wrap the native Cpp object. Args: - compilation_feeback (_CompilationFeedback): object to wrap + circuit_compilation_feeback (_CircuitCompilationFeedback): object to wrap Raises: - TypeError: if compilation_feedback is not of type _CompilationFeedback + TypeError: if circuit_compilation_feedback is not of type _CircuitCompilationFeedback """ - if not isinstance(compilation_feedback, _CompilationFeedback): + if not isinstance(circuit_compilation_feedback, _CircuitCompilationFeedback): raise TypeError( - f"compilation_feedback must be of type _CompilationFeedback, not {type(compilation_feedback)}" + "circuit_compilation_feedback must be of type " + f"_CircuitCompilationFeedback, not {type(circuit_compilation_feedback)}" ) - self.complexity = compilation_feedback.complexity - self.p_error = compilation_feedback.p_error - self.global_p_error = compilation_feedback.global_p_error - self.total_secret_keys_size = compilation_feedback.total_secret_keys_size - self.total_bootstrap_keys_size = compilation_feedback.total_bootstrap_keys_size - self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size - self.total_inputs_size = compilation_feedback.total_inputs_size - self.total_output_size = compilation_feedback.total_output_size + self.name = circuit_compilation_feedback.name + self.total_inputs_size = circuit_compilation_feedback.total_inputs_size + self.total_output_size = circuit_compilation_feedback.total_output_size self.crt_decompositions_of_outputs = ( - compilation_feedback.crt_decompositions_of_outputs + circuit_compilation_feedback.crt_decompositions_of_outputs + ) + self.statistics = circuit_compilation_feedback.statistics + self.memory_usage_per_location = ( + circuit_compilation_feedback.memory_usage_per_location ) - self.statistics = compilation_feedback.statistics - self.memory_usage_per_location = compilation_feedback.memory_usage_per_location - super().__init__(compilation_feedback) + super().__init__(circuit_compilation_feedback) def count(self, *, operations: Set[PrimitiveOperation]) -> int: """ @@ -216,3 +215,68 @@ class CompilationFeedback(WrapperCpp): result[current_tag][parameter] += statistic.count return result + + +class ProgramCompilationFeedback(WrapperCpp): + """CompilationFeedback is a set of hint computed by the compiler engine.""" + + def __init__(self, program_compilation_feedback: _ProgramCompilationFeedback): + """Wrap the native Cpp object. + + Args: + compilation_feeback (_CompilationFeedback): object to wrap + + Raises: + TypeError: if program_compilation_feedback is not of type _CompilationFeedback + """ + if not isinstance(program_compilation_feedback, _ProgramCompilationFeedback): + raise TypeError( + "program_compilation_feedback must be of type " + f"_CompilationFeedback, not {type(program_compilation_feedback)}" + ) + + self.complexity = program_compilation_feedback.complexity + self.p_error = program_compilation_feedback.p_error + self.global_p_error = program_compilation_feedback.global_p_error + self.total_secret_keys_size = ( + program_compilation_feedback.total_secret_keys_size + ) + self.total_bootstrap_keys_size = ( + program_compilation_feedback.total_bootstrap_keys_size + ) + self.total_keyswitch_keys_size = ( + program_compilation_feedback.total_keyswitch_keys_size + ) + self.circuit_feedbacks = [ + CircuitCompilationFeedback(c) + for c in program_compilation_feedback.circuit_feedbacks + ] + + super().__init__(program_compilation_feedback) + + def circuit(self, circuit_name: str) -> CircuitCompilationFeedback: + """ + Returns the feedback for the circuit circuit_name. + + Args: + circuit_name (str): + the name of the circuit. + + Returns: + CircuitCompilationFeedback: + the feedback for the circuit. + + Raises: + TypeError: if the circuit_name is not a string + ValueError: if there is no circuit with name circuit_name + """ + if not isinstance(circuit_name, str): + raise TypeError( + f"circuit_name must be of type str, not {type(circuit_name)}" + ) + + for circuit in self.circuit_feedbacks: + if circuit.name == circuit_name: + return circuit + + raise ValueError(f"no circuit with name {circuit_name} found") diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py index 10a50cfdb..b3c0294bf 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py @@ -43,11 +43,11 @@ class CompilationOptions(WrapperCpp): @staticmethod # pylint: disable=arguments-differ - def new(function_name="main", backend=_Backend.CPU) -> "CompilationOptions": + def new(backend=_Backend.CPU) -> "CompilationOptions": """Build a CompilationOptions. Args: - function_name (str, optional): name of the entrypoint function. Defaults to "main". + backend (_Backend): backend to use. Raises: TypeError: if function_name is not an str @@ -55,15 +55,9 @@ class CompilationOptions(WrapperCpp): Returns: CompilationOptions """ - if not isinstance(function_name, str): - raise TypeError( - f"function_name must be of type str not {type(function_name)}" - ) if not isinstance(backend, _Backend): - raise TypeError( - f"backend must be of type Backend not {type(function_name)}" - ) - return CompilationOptions.wrap(_CompilationOptions(function_name, backend)) + raise TypeError(f"backend must be of type Backend not {type(backend)}") + return CompilationOptions.wrap(_CompilationOptions(backend)) # pylint: enable=arguments-differ diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py index cb2df5c97..53165a5bd 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py @@ -33,16 +33,14 @@ class LibraryCompilationResult(WrapperCpp): @staticmethod # pylint: disable=arguments-differ - def new(output_dir_path: str, func_name: str) -> "LibraryCompilationResult": - """Build a LibraryCompilationResult at output_dir_path, with func_name as entrypoint. + def new(output_dir_path: str) -> "LibraryCompilationResult": + """Build a LibraryCompilationResult at output_dir_path. Args: output_dir_path (str): path to the compilation artifacts - func_name (str): entrypoint function name Raises: TypeError: if output_dir_path is not of type str - TypeError: if func_name is not of type str Returns: LibraryCompilationResult @@ -51,10 +49,6 @@ class LibraryCompilationResult(WrapperCpp): raise TypeError( f"output_dir_path must be of type str, not {type(output_dir_path)}" ) - if not isinstance(func_name, str): - raise TypeError(f"func_name must be of type str, not {type(func_name)}") - return LibraryCompilationResult.wrap( - _LibraryCompilationResult(output_dir_path, func_name) - ) + return LibraryCompilationResult.wrap(_LibraryCompilationResult(output_dir_path)) # pylint: enable=arguments-differ diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_support.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_support.py index 6ef57a610..3fd723331 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_support.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_support.py @@ -23,7 +23,7 @@ from .public_arguments import PublicArguments from .library_lambda import LibraryLambda from .public_result import PublicResult from .client_parameters import ClientParameters -from .compilation_feedback import CompilationFeedback +from .compilation_feedback import ProgramCompilationFeedback from .wrapper import WrapperCpp from .utils import lookup_runtime_lib from .evaluation_keys import EvaluationKeys @@ -132,7 +132,7 @@ class LibrarySupport(WrapperCpp): def compile( self, mlir_program: Union[str, MlirModule], - options: CompilationOptions = CompilationOptions.new("main"), + options: CompilationOptions = CompilationOptions.new(), compilation_context: Optional[CompilationContext] = None, ) -> LibraryCompilationResult: """Compile an MLIR program using Concrete dialects into a library. @@ -178,18 +178,13 @@ class LibrarySupport(WrapperCpp): self.cpp().compile(mlir_program, options.cpp()) ) - def reload(self, func_name: str = "main") -> LibraryCompilationResult: + def reload(self) -> LibraryCompilationResult: """Reload the library compilation result from the output_dir_path. - Args: - func_name: entrypoint function name - Returns: LibraryCompilationResult: loaded library """ - if not isinstance(func_name, str): - raise TypeError(f"func_name must be of type str, not {type(func_name)}") - return LibraryCompilationResult.new(self.output_dir_path, func_name) + return LibraryCompilationResult.new(self.output_dir_path) def load_client_parameters( self, library_compilation_result: LibraryCompilationResult @@ -217,7 +212,7 @@ class LibrarySupport(WrapperCpp): def load_compilation_feedback( self, compilation_result: LibraryCompilationResult - ) -> CompilationFeedback: + ) -> ProgramCompilationFeedback: """Load the compilation feedback from the compilation result. Args: @@ -227,13 +222,13 @@ class LibrarySupport(WrapperCpp): TypeError: if compilation_result is not of type LibraryCompilationResult Returns: - CompilationFeedback: the compilation feedback for the compiled program + ProgramCompilationFeedback: the compilation feedback for the compiled program """ if not isinstance(compilation_result, LibraryCompilationResult): raise TypeError( f"compilation_result must be of type LibraryCompilationResult, not {type(compilation_result)}" ) - return CompilationFeedback.wrap( + return ProgramCompilationFeedback.wrap( self.cpp().load_compilation_feedback(compilation_result.cpp()) ) @@ -241,14 +236,18 @@ class LibrarySupport(WrapperCpp): self, library_compilation_result: LibraryCompilationResult, simulation: bool, + circuit_name: str = "main", ) -> LibraryLambda: - """Load the server lambda from the library compilation result. + """Load the server lambda for a given circuit from the library compilation result. Args: library_compilation_result (LibraryCompilationResult): compilation result of the library + simulation (bool): use simulation for execution + circuit_name (str): name of the circuit to be loaded Raises: - TypeError: if library_compilation_result is not of type LibraryCompilationResult + TypeError: if library_compilation_result is not of type LibraryCompilationResult, if + circuit_name is not of type str or Returns: LibraryLambda: executable reference to the library @@ -258,8 +257,18 @@ class LibrarySupport(WrapperCpp): f"library_compilation_result must be of type LibraryCompilationResult, not " f"{type(library_compilation_result)}" ) + if not isinstance(circuit_name, str): + raise TypeError( + f"circuit_name must be of type str, not " f"{type(circuit_name)}" + ) + if not isinstance(simulation, bool): + raise TypeError( + f"simulation must be of type bool, not " f"{type(simulation)}" + ) return LibraryLambda.wrap( - self.cpp().load_server_lambda(library_compilation_result.cpp(), simulation) + self.cpp().load_server_lambda( + library_compilation_result.cpp(), circuit_name, simulation + ) ) def server_call( diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_decrypter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_decrypter.py index e153ce554..f00055d30 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_decrypter.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_decrypter.py @@ -41,12 +41,12 @@ class SimulatedValueDecrypter(WrapperCpp): @staticmethod # pylint: disable=arguments-differ - def new(client_parameters: ClientParameters): + def new(client_parameters: ClientParameters, circuit_name: str = "main"): """ Create a value decrypter. """ return SimulatedValueDecrypter( - _SimulatedValueDecrypter.create(client_parameters.cpp()) + _SimulatedValueDecrypter.create(client_parameters.cpp(), circuit_name) ) def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]: diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_exporter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_exporter.py index 79cba8804..10deb348e 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_exporter.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_exporter.py @@ -40,12 +40,14 @@ class SimulatedValueExporter(WrapperCpp): @staticmethod # pylint: disable=arguments-differ - def new(client_parameters: ClientParameters) -> "SimulatedValueExporter": + def new( + client_parameters: ClientParameters, circuitName: str = "main" + ) -> "SimulatedValueExporter": """ Create a value exporter. """ return SimulatedValueExporter( - _SimulatedValueExporter.create(client_parameters.cpp()) + _SimulatedValueExporter.create(client_parameters.cpp(), circuitName) ) def export_scalar(self, position: int, value: int) -> Value: diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py index ab369773b..d424136e3 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py @@ -42,12 +42,14 @@ class ValueDecrypter(WrapperCpp): @staticmethod # pylint: disable=arguments-differ - def new(keyset: KeySet, client_parameters: ClientParameters): + def new( + keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main" + ): """ Create a value decrypter. """ return ValueDecrypter( - _ValueDecrypter.create(keyset.cpp(), client_parameters.cpp()) + _ValueDecrypter.create(keyset.cpp(), client_parameters.cpp(), circuit_name) ) def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]: diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py index 1ddd4a45b..84a585eb6 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py @@ -41,12 +41,14 @@ class ValueExporter(WrapperCpp): @staticmethod # pylint: disable=arguments-differ - def new(keyset: KeySet, client_parameters: ClientParameters) -> "ValueExporter": + def new( + keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main" + ) -> "ValueExporter": """ Create a value exporter. """ return ValueExporter( - _ValueExporter.create(keyset.cpp(), client_parameters.cpp()) + _ValueExporter.create(keyset.cpp(), client_parameters.cpp(), circuit_name) ) def export_scalar(self, position: int, value: int) -> Value: diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp index d90c62ef7..b4958735e 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -66,34 +67,48 @@ namespace Concrete { struct MemoryUsagePass : public PassWrapper> { - CompilationFeedback &feedback; + ProgramCompilationFeedback &feedback; + CircuitCompilationFeedback *circuitFeedback; - MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {}; + MemoryUsagePass(ProgramCompilationFeedback &feedback) + : feedback{feedback}, circuitFeedback{nullptr} {}; void runOnOperation() override { - WalkResult walk = - getOperation()->walk([&](Operation *op, const WalkStage &stage) { - if (stage.isBeforeAllRegions()) { - std::optional error = this->enter(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); + auto module = getOperation(); + auto funcs = module.getOps(); + for (CircuitCompilationFeedback &circuitFeedback : + feedback.circuitFeedbacks) { + auto funcOp = llvm::find_if(funcs, [&](mlir::func::FuncOp op) { + return op.getName() == circuitFeedback.name; + }); + assert(funcOp != funcs.end()); + this->circuitFeedback = &circuitFeedback; + + WalkResult walk = + getOperation()->walk([&](Operation *op, const WalkStage &stage) { + if (stage.isBeforeAllRegions()) { + std::optional error = this->enter(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } } - } - if (stage.isAfterAllRegions()) { - std::optional error = this->exit(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); + if (stage.isAfterAllRegions()) { + std::optional error = this->exit(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } } - } - return WalkResult::advance(); - }); + return WalkResult::advance(); + }); - if (walk.wasInterrupted()) { - signalPassFailure(); + if (walk.wasInterrupted()) { + signalPassFailure(); + return; + } } } @@ -172,7 +187,7 @@ struct MemoryUsagePass // element_size auto memoryUsage = numberOfAlloc * maybeBufferSize.value(); - pass.feedback.memoryUsagePerLoc[location] += memoryUsage; + pass.circuitFeedback->memoryUsagePerLoc[location] += memoryUsage; return std::nullopt; } @@ -218,7 +233,7 @@ struct MemoryUsagePass } auto bufferSize = maybeBufferSize.value(); - pass.feedback.memoryUsagePerLoc[location] += bufferSize; + pass.circuitFeedback->memoryUsagePerLoc[location] += bufferSize; } } @@ -233,7 +248,7 @@ struct MemoryUsagePass } // namespace Concrete std::unique_ptr> -createMemoryUsagePass(CompilationFeedback &feedback) { +createMemoryUsagePass(ProgramCompilationFeedback &feedback) { return std::make_unique(feedback); } diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp index 781e46fa7..4b7aeaad3 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp @@ -1,3 +1,4 @@ +#include "concretelang/Support/CompilationFeedback.h" #include #include @@ -34,35 +35,48 @@ namespace TFHE { struct ExtractTFHEStatisticsPass : public PassWrapper> { - CompilationFeedback &feedback; + ProgramCompilationFeedback &feedback; + CircuitCompilationFeedback *circuitFeedback; - ExtractTFHEStatisticsPass(CompilationFeedback &feedback) - : feedback{feedback} {}; + ExtractTFHEStatisticsPass(ProgramCompilationFeedback &feedback) + : feedback{feedback}, circuitFeedback{nullptr} {}; void runOnOperation() override { - WalkResult walk = - getOperation()->walk([&](Operation *op, const WalkStage &stage) { - if (stage.isBeforeAllRegions()) { - std::optional error = this->enter(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); + auto module = getOperation(); + auto funcs = module.getOps(); + for (CircuitCompilationFeedback &circuitFeedback : + feedback.circuitFeedbacks) { + auto funcOp = llvm::find_if(funcs, [&](mlir::func::FuncOp op) { + return op.getName() == circuitFeedback.name; + }); + assert(funcOp != funcs.end()); + this->circuitFeedback = &circuitFeedback; + + WalkResult walk = + (*funcOp)->walk([&](Operation *op, const WalkStage &stage) { + if (stage.isBeforeAllRegions()) { + std::optional error = this->enter(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } } - } - if (stage.isAfterAllRegions()) { - std::optional error = this->exit(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); + if (stage.isAfterAllRegions()) { + std::optional error = this->exit(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } } - } - return WalkResult::advance(); - }); + return WalkResult::advance(); + }); - if (walk.wasInterrupted()) { - signalPassFailure(); + if (walk.wasInterrupted()) { + signalPassFailure(); + return; + } } } @@ -118,14 +132,14 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::ENCRYPTED_ADDITION; - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + std::pair key = + std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -145,14 +159,14 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::CLEAR_ADDITION; - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + std::pair key = + std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -172,14 +186,14 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::PBS; - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); + std::pair key = + std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex()); keys.push_back(key); - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -199,14 +213,14 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::KEY_SWITCH; - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); + std::pair key = + std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex()); keys.push_back(key); - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -226,14 +240,14 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION; - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + std::pair key = + std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -253,14 +267,14 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + std::pair key = + std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -279,17 +293,17 @@ struct ExtractTFHEStatisticsPass auto resultingKey = op.getType().getKey().getNormalized(); auto location = locationString(op.getLoc()); - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + std::pair key = + std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); // clear - encrypted = clear + neg(encrypted) auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -297,7 +311,7 @@ struct ExtractTFHEStatisticsPass }); operation = PrimitiveOperation::CLEAR_ADDITION; - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -319,20 +333,20 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::WOP_PBS; - auto keys = std::vector>(); + auto keys = std::vector>(); auto count = pass.iterations; - std::pair key = - std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); + std::pair key = + std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex()); keys.push_back(key); - key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); + key = std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex()); keys.push_back(key); - key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex()); + key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (int64_t)pksk.getIndex()); keys.push_back(key); - pass.feedback.statistics.push_back(concretelang::Statistic{ + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, @@ -342,13 +356,13 @@ struct ExtractTFHEStatisticsPass return std::nullopt; } - size_t iterations = 1; + int64_t iterations = 1; }; } // namespace TFHE std::unique_ptr> -createStatisticExtractionPass(CompilationFeedback &feedback) { +createStatisticExtractionPass(ProgramCompilationFeedback &feedback) { return std::make_unique(feedback); } diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp index fc32b065e..48d4b9a8e 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp @@ -5,7 +5,9 @@ #include #include +#include +#include "concrete-protocol.capnp.h" #include "concretelang/Support/CompilationFeedback.h" using concretelang::protocol::Message; @@ -13,60 +15,11 @@ using concretelang::protocol::Message; namespace mlir { namespace concretelang { -void CompilationFeedback::fillFromProgramInfo( - const Message &programInfo) { - auto params = programInfo.asReader(); - - // Compute the size of secret keys - totalSecretKeysSize = 0; - for (auto skInfo : params.getKeyset().getLweSecretKeys()) { - assert(skInfo.getParams().getIntegerPrecision() % 8 == 0); - auto byteSize = skInfo.getParams().getIntegerPrecision() / 8; - totalSecretKeysSize += skInfo.getParams().getLweDimension() * byteSize; - } - // Compute the boostrap keys size - totalBootstrapKeysSize = 0; - for (auto bskInfo : params.getKeyset().getLweBootstrapKeys()) { - assert(bskInfo.getInputId() < - (uint32_t)params.getKeyset().getLweSecretKeys().size()); - auto inputKeyInfo = - params.getKeyset().getLweSecretKeys()[bskInfo.getInputId()]; - assert(bskInfo.getOutputId() < - (uint32_t)params.getKeyset().getLweSecretKeys().size()); - auto outputKeyInfo = - params.getKeyset().getLweSecretKeys()[bskInfo.getOutputId()]; - assert(bskInfo.getParams().getIntegerPrecision() % 8 == 0); - auto byteSize = bskInfo.getParams().getIntegerPrecision() % 8; - auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1; - auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1; - auto level = bskInfo.getParams().getLevelCount(); - auto glweDimension = bskInfo.getParams().getGlweDimension(); - totalBootstrapKeysSize += inputLweSize * level * (glweDimension + 1) * - (glweDimension + 1) * outputLweSize * byteSize; - } - // Compute the keyswitch keys size - totalKeyswitchKeysSize = 0; - for (auto kskInfo : params.getKeyset().getLweKeyswitchKeys()) { - assert(kskInfo.getInputId() < - (uint32_t)params.getKeyset().getLweSecretKeys().size()); - auto inputKeyInfo = - params.getKeyset().getLweSecretKeys()[kskInfo.getInputId()]; - assert(kskInfo.getOutputId() < - (uint32_t)params.getKeyset().getLweSecretKeys().size()); - auto outputKeyInfo = - params.getKeyset().getLweSecretKeys()[kskInfo.getOutputId()]; - assert(kskInfo.getParams().getIntegerPrecision() % 8 == 0); - auto byteSize = kskInfo.getParams().getIntegerPrecision() % 8; - auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1; - auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1; - auto level = kskInfo.getParams().getLevelCount(); - totalKeyswitchKeysSize += level * inputLweSize * outputLweSize * byteSize; - } - auto circuitInfo = params.getCircuits()[0]; +void CircuitCompilationFeedback::fillFromCircuitInfo( + concreteprotocol::CircuitInfo::Reader circuitInfo) { auto computeGateSize = [&](const Message &gateInfo) { unsigned int nElements = 1; - // TODO: CHANGE THAT ITS WRONG for (auto dimension : gateInfo.asReader().getRawInfo().getShape().getDimensions()) { nElements *= dimension; @@ -104,17 +57,77 @@ void CompilationFeedback::fillFromProgramInfo( } } } + // Sets name + name = circuitInfo.getName().cStr(); } -outcome::checked -CompilationFeedback::load(std::string jsonPath) { +void ProgramCompilationFeedback::fillFromProgramInfo( + const Message &programInfo) { + auto params = programInfo.asReader(); + + // Compute the size of secret keys + totalSecretKeysSize = 0; + for (auto skInfo : params.getKeyset().getLweSecretKeys()) { + assert(skInfo.getParams().getIntegerPrecision() % 8 == 0); + auto byteSize = skInfo.getParams().getIntegerPrecision() / 8; + totalSecretKeysSize += skInfo.getParams().getLweDimension() * byteSize; + } + // Compute the boostrap keys size + totalBootstrapKeysSize = 0; + for (auto bskInfo : params.getKeyset().getLweBootstrapKeys()) { + assert(bskInfo.getInputId() < + (uint32_t)params.getKeyset().getLweSecretKeys().size()); + auto inputKeyInfo = + params.getKeyset().getLweSecretKeys()[bskInfo.getInputId()]; + assert(bskInfo.getOutputId() < + (uint32_t)params.getKeyset().getLweSecretKeys().size()); + auto outputKeyInfo = + params.getKeyset().getLweSecretKeys()[bskInfo.getOutputId()]; + assert(bskInfo.getParams().getIntegerPrecision() % 8 == 0); + auto byteSize = bskInfo.getParams().getIntegerPrecision() / 8; + auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1; + auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1; + auto level = bskInfo.getParams().getLevelCount(); + auto glweDimension = bskInfo.getParams().getGlweDimension(); + totalBootstrapKeysSize += inputLweSize * level * (glweDimension + 1) * + (glweDimension + 1) * outputLweSize * byteSize; + } + // Compute the keyswitch keys size + totalKeyswitchKeysSize = 0; + for (auto kskInfo : params.getKeyset().getLweKeyswitchKeys()) { + assert(kskInfo.getInputId() < + (uint32_t)params.getKeyset().getLweSecretKeys().size()); + auto inputKeyInfo = + params.getKeyset().getLweSecretKeys()[kskInfo.getInputId()]; + assert(kskInfo.getOutputId() < + (uint32_t)params.getKeyset().getLweSecretKeys().size()); + auto outputKeyInfo = + params.getKeyset().getLweSecretKeys()[kskInfo.getOutputId()]; + assert(kskInfo.getParams().getIntegerPrecision() % 8 == 0); + auto byteSize = kskInfo.getParams().getIntegerPrecision() / 8; + auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1; + auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1; + auto level = kskInfo.getParams().getLevelCount(); + totalKeyswitchKeysSize += level * inputLweSize * outputLweSize * byteSize; + } + // Compute the circuit feedbacks + for (auto circuitInfo : params.getCircuits()) { + CircuitCompilationFeedback feedback; + feedback.fillFromCircuitInfo(circuitInfo); + circuitFeedbacks.push_back(feedback); + } +} + +outcome::checked +ProgramCompilationFeedback::load(std::string jsonPath) { std::ifstream file(jsonPath); std::string content((std::istreambuf_iterator(file)), (std::istreambuf_iterator())); if (file.fail()) { return StringError("Cannot read file: ") << jsonPath; } - auto expectedCompFeedback = llvm::json::parse(content); + auto expectedCompFeedback = + llvm::json::parse(content); if (auto err = expectedCompFeedback.takeError()) { return StringError("Cannot open compilation feedback: ") << llvm::toString(std::move(err)) << "\n" @@ -123,196 +136,228 @@ CompilationFeedback::load(std::string jsonPath) { return expectedCompFeedback.get(); } -llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) { - llvm::json::Object object{ - {"complexity", v.complexity}, - {"pError", v.pError}, - {"globalPError", v.globalPError}, - {"totalSecretKeysSize", v.totalSecretKeysSize}, - {"totalBootstrapKeysSize", v.totalBootstrapKeysSize}, - {"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize}, - {"totalInputsSize", v.totalInputsSize}, - {"totalOutputsSize", v.totalOutputsSize}, - {"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs}, - }; - - auto memoryUsageObject = llvm::json::Object(); - for (auto key : v.memoryUsagePerLoc) { - memoryUsageObject.insert({key.first, key.second}); +llvm::json::Object +memoryUsageToJson(const std::map &memoryUsagePerLoc) { + auto object = llvm::json::Object(); + for (auto key : memoryUsagePerLoc) { + object.insert({key.first, key.second}); } - object.insert({"memoryUsagePerLoc", std::move(memoryUsageObject)}); - - auto statisticsJson = llvm::json::Array(); - for (auto statistic : v.statistics) { - auto statisticJson = llvm::json::Object(); - statisticJson.insert({"location", statistic.location}); - switch (statistic.operation) { - case PrimitiveOperation::PBS: - statisticJson.insert({"operation", "PBS"}); - break; - case PrimitiveOperation::WOP_PBS: - statisticJson.insert({"operation", "WOP_PBS"}); - break; - case PrimitiveOperation::KEY_SWITCH: - statisticJson.insert({"operation", "KEY_SWITCH"}); - break; - case PrimitiveOperation::CLEAR_ADDITION: - statisticJson.insert({"operation", "CLEAR_ADDITION"}); - break; - case PrimitiveOperation::ENCRYPTED_ADDITION: - statisticJson.insert({"operation", "ENCRYPTED_ADDITION"}); - break; - case PrimitiveOperation::CLEAR_MULTIPLICATION: - statisticJson.insert({"operation", "CLEAR_MULTIPLICATION"}); - break; - case PrimitiveOperation::ENCRYPTED_NEGATION: - statisticJson.insert({"operation", "ENCRYPTED_NEGATION"}); - break; - } - auto keysJson = llvm::json::Array(); - for (auto &key : statistic.keys) { - KeyType type = key.first; - size_t index = key.second; - - auto keyJson = llvm::json::Array(); - switch (type) { - case KeyType::SECRET: - keyJson.push_back("SECRET"); - break; - case KeyType::BOOTSTRAP: - keyJson.push_back("BOOTSTRAP"); - break; - case KeyType::KEY_SWITCH: - keyJson.push_back("KEY_SWITCH"); - break; - case KeyType::PACKING_KEY_SWITCH: - keyJson.push_back("PACKING_KEY_SWITCH"); - break; - } - keyJson.push_back((int64_t)index); - - keysJson.push_back(std::move(keyJson)); - } - statisticJson.insert({"keys", std::move(keysJson)}); - statisticJson.insert({"count", (int64_t)statistic.count}); - - statisticsJson.push_back(std::move(statisticJson)); - } - object.insert({"statistics", std::move(statisticsJson)}); - return object; } +llvm::json::Object statisticToJson(const Statistic &statistic) { + auto object = llvm::json::Object(); + object.insert({"location", statistic.location}); + object.insert({"count", statistic.count}); + switch (statistic.operation) { + case PrimitiveOperation::PBS: + object.insert({"operation", "PBS"}); + break; + case PrimitiveOperation::WOP_PBS: + object.insert({"operation", "WOP_PBS"}); + break; + case PrimitiveOperation::KEY_SWITCH: + object.insert({"operation", "KEY_SWITCH"}); + break; + case PrimitiveOperation::CLEAR_ADDITION: + object.insert({"operation", "CLEAR_ADDITION"}); + break; + case PrimitiveOperation::ENCRYPTED_ADDITION: + object.insert({"operation", "ENCRYPTED_ADDITION"}); + break; + case PrimitiveOperation::CLEAR_MULTIPLICATION: + object.insert({"operation", "CLEAR_MULTIPLICATION"}); + break; + case PrimitiveOperation::ENCRYPTED_NEGATION: + object.insert({"operation", "ENCRYPTED_NEGATION"}); + break; + } + auto keysJson = llvm::json::Array(); + for (auto &key : statistic.keys) { + KeyType type = key.first; + size_t index = key.second; + + auto keyJson = llvm::json::Array(); + switch (type) { + case KeyType::SECRET: + keyJson.push_back("SECRET"); + break; + case KeyType::BOOTSTRAP: + keyJson.push_back("BOOTSTRAP"); + break; + case KeyType::KEY_SWITCH: + keyJson.push_back("KEY_SWITCH"); + break; + case KeyType::PACKING_KEY_SWITCH: + keyJson.push_back("PACKING_KEY_SWITCH"); + break; + } + keyJson.push_back((int64_t)index); + + keysJson.push_back(std::move(keyJson)); + } + object.insert({"keys", std::move(keysJson)}); + return object; +} + +llvm::json::Array statisticsToJson(const std::vector &statistics) { + auto object = llvm::json::Array(); + for (auto statistic : statistics) { + object.push_back(statisticToJson(statistic)); + } + return object; +} + +llvm::json::Array crtDecompositionToJson( + const std::vector> &crtDecompositionsOfOutputs) { + auto object = llvm::json::Array(); + for (auto crtDec : crtDecompositionsOfOutputs) { + auto inner = llvm::json::Array(); + for (auto val : crtDec) { + inner.push_back(val); + } + object.push_back(std::move(inner)); + } + return object; +} + +llvm::json::Array circuitFeedbacksToJson( + const std::vector &circuitFeedbacks) { + auto object = llvm::json::Array(); + for (auto circuit : circuitFeedbacks) { + llvm::json::Object circuitObject{ + {"name", circuit.name}, + {"totalInputsSize", circuit.totalInputsSize}, + {"totalOutputsSize", circuit.totalOutputsSize}, + {"crtDecompositionsOfOutputs", + crtDecompositionToJson(circuit.crtDecompositionsOfOutputs)}, + {"statistics", statisticsToJson(circuit.statistics)}, + {"memoryUsagePerLoc", memoryUsageToJson(circuit.memoryUsagePerLoc)}, + }; + object.push_back(std::move(circuitObject)); + } + return object; +} + +llvm::json::Value +toJSON(const mlir::concretelang::ProgramCompilationFeedback &program) { + llvm::json::Object programObject{ + {"complexity", program.complexity}, + {"pError", program.pError}, + {"globalPError", program.globalPError}, + {"totalSecretKeysSize", program.totalSecretKeysSize}, + {"totalBootstrapKeysSize", program.totalBootstrapKeysSize}, + {"totalKeyswitchKeysSize", program.totalKeyswitchKeysSize}, + {"circuitFeedbacks", circuitFeedbacksToJson(program.circuitFeedbacks)}}; + return programObject; +} + +template +bool fromJSON(const llvm::json::Value &j, std::pair &v, + llvm::json::Path p) { + if (auto *array = j.getAsArray()) { + if (!fromJSON((*array)[0], v.first, p.index(0))) + return false; + if (!fromJSON((*array)[1], v.second, p.index(1))) + return false; + return true; + } + p.report("expected array"); + return false; +} + bool fromJSON(const llvm::json::Value j, - mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) { + mlir::concretelang::PrimitiveOperation &v, llvm::json::Path p) { + if (auto operationString = j.getAsString()) { + if (operationString == "PBS") { + v = PrimitiveOperation::PBS; + return true; + } else if (operationString == "KEY_SWITCH") { + v = PrimitiveOperation::KEY_SWITCH; + return true; + } else if (operationString == "WOP_PBS") { + v = PrimitiveOperation::WOP_PBS; + return true; + } else if (operationString == "CLEAR_ADDITION") { + v = PrimitiveOperation::CLEAR_ADDITION; + return true; + } else if (operationString == "ENCRYPTED_ADDITION") { + v = PrimitiveOperation::ENCRYPTED_ADDITION; + return true; + } else if (operationString == "CLEAR_MULTIPLICATION") { + v = PrimitiveOperation::CLEAR_MULTIPLICATION; + return true; + } else if (operationString == "ENCRYPTED_NEGATION") { + v = PrimitiveOperation::ENCRYPTED_NEGATION; + return true; + } else { + p.report("expected one of " + "(PBS|KEY_SWITCH|WOP_PBS|CLEAR_ADDITION|ENCRYPTED_ADDITION|" + "CLEAR_MULTIPLICATION|ENCRYPTED_NEGATION)"); + return false; + } + } + p.report("expected string"); + return false; +} + +bool fromJSON(const llvm::json::Value j, mlir::concretelang::KeyType &v, + llvm::json::Path p) { + if (auto keyTypeString = j.getAsString()) { + if (keyTypeString == "SECRET") { + v = KeyType::SECRET; + return true; + } else if (keyTypeString == "BOOTSTRAP") { + v = KeyType::BOOTSTRAP; + return true; + } else if (keyTypeString == "KEY_SWITCH") { + v = KeyType::KEY_SWITCH; + return true; + } else if (keyTypeString == "PACKING_KEY_SWITCH") { + v = KeyType::PACKING_KEY_SWITCH; + return true; + } else { + p.report( + "expected one of (SECRET|BOOTSTRAP|KEY_SWITCH|PACKING_KEY_SWITCH)"); + return false; + } + } + p.report("expected string"); + return false; +} + +bool fromJSON(const llvm::json::Value j, mlir::concretelang::Statistic &v, + llvm::json::Path p) { llvm::json::ObjectMapper O(j, p); - bool is_success = - O && O.map("complexity", v.complexity) && O.map("pError", v.pError) && - O.map("globalPError", v.globalPError) && - O.map("totalSecretKeysSize", v.totalSecretKeysSize) && - O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) && - O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) && - O.map("totalInputsSize", v.totalInputsSize) && - O.map("totalOutputsSize", v.totalOutputsSize) && - O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs); + return O && O.map("location", v.location) && + O.map("operation", v.operation) && O.map("operation", v.operation) && + O.map("keys", v.keys) && O.map("count", v.count); +} - if (!is_success) { - return false; - } +bool fromJSON(const llvm::json::Value j, + mlir::concretelang::CircuitCompilationFeedback &v, + llvm::json::Path p) { + llvm::json::ObjectMapper O(j, p); + return O && O.map("name", v.name) && + O.map("totalInputsSize", v.totalInputsSize) && + O.map("totalOutputsSize", v.totalOutputsSize) && + O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs) && + O.map("statistics", v.statistics) && + O.map("memoryUsagePerLoc", v.memoryUsagePerLoc); +} - auto object = j.getAsObject(); - if (!object) { - return false; - } +bool fromJSON(const llvm::json::Value j, + mlir::concretelang::ProgramCompilationFeedback &v, + llvm::json::Path p) { + llvm::json::ObjectMapper O(j, p); - auto memoryUsageObject = object->getObject("memoryUsagePerLoc"); - if (!memoryUsageObject) { - return false; - } - for (auto entry : *memoryUsageObject) { - auto loc = entry.getFirst().str(); - auto maybeUsage = entry.getSecond().getAsInteger(); - if (!maybeUsage.has_value()) { - return false; - } - v.memoryUsagePerLoc[loc] = *maybeUsage; - } - - auto statistics = object->getArray("statistics"); - if (!statistics) { - return false; - } - - for (auto statisticValue : *statistics) { - auto statistic = statisticValue.getAsObject(); - if (!statistic) { - return false; - } - - auto location = statistic->getString("location"); - auto operationStr = statistic->getString("operation"); - auto keysArray = statistic->getArray("keys"); - auto count = statistic->getInteger("count"); - - if (!operationStr || !location || !keysArray || !count) { - return false; - } - - PrimitiveOperation operation; - if (operationStr.value() == "PBS") { - operation = PrimitiveOperation::PBS; - } else if (operationStr.value() == "KEY_SWITCH") { - operation = PrimitiveOperation::KEY_SWITCH; - } else if (operationStr.value() == "WOP_PBS") { - operation = PrimitiveOperation::WOP_PBS; - } else if (operationStr.value() == "CLEAR_ADDITION") { - operation = PrimitiveOperation::CLEAR_ADDITION; - } else if (operationStr.value() == "ENCRYPTED_ADDITION") { - operation = PrimitiveOperation::ENCRYPTED_ADDITION; - } else if (operationStr.value() == "CLEAR_MULTIPLICATION") { - operation = PrimitiveOperation::CLEAR_MULTIPLICATION; - } else if (operationStr.value() == "ENCRYPTED_NEGATION") { - operation = PrimitiveOperation::ENCRYPTED_NEGATION; - } else { - return false; - } - - auto keys = std::vector>(); - for (auto keyValue : *keysArray) { - llvm::json::Array *keyArray = keyValue.getAsArray(); - if (!keyArray || keyArray->size() != 2) { - return false; - } - - auto typeStr = keyArray->front().getAsString(); - auto index = keyArray->back().getAsInteger(); - - if (!typeStr || !index) { - return false; - } - - KeyType type; - if (typeStr.value() == "SECRET") { - type = KeyType::SECRET; - } else if (typeStr.value() == "BOOTSTRAP") { - type = KeyType::BOOTSTRAP; - } else if (typeStr.value() == "KEY_SWITCH") { - type = KeyType::KEY_SWITCH; - } else if (typeStr.value() == "PACKING_KEY_SWITCH") { - type = KeyType::PACKING_KEY_SWITCH; - } else { - return false; - } - - keys.push_back(std::make_pair(type, (size_t)*index)); - } - - v.statistics.push_back( - Statistic{location->str(), operation, keys, (uint64_t)*count}); - } - - return true; + return O && O.map("complexity", v.complexity) && O.map("pError", v.pError) && + O.map("globalPError", v.globalPError) && + O.map("totalSecretKeysSize", v.totalSecretKeysSize) && + O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) && + O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) && + O.map("circuitFeedbacks", v.circuitFeedbacks); } } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index 6b62c75d3..cddc262c4 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -3,6 +3,7 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include "concretelang/Support/V0Parameters.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" @@ -166,19 +167,33 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) { if (descriptions->empty()) { // The pass has not been run return std::nullopt; } - if (this->compilerOptions.mainFuncName.has_value()) { - auto name = this->compilerOptions.mainFuncName.value(); - auto description = descriptions->find(name); - if (description == descriptions->end()) { - std::string names; - return StreamStringError("Function not found, name='") - << name << "', cannot get optimizer description"; - } - return std::move(description->second); + if (descriptions->size() > 1 && + config.strategy != + mlir::concretelang::optimizer::V0) { // Multi circuits without V0 + return StreamStringError( + "Multi-circuits is only supported for V0 optimization."); } - if (descriptions->size() != 1) { - llvm::errs() << "Several crypto parameters exists: the function need to be " - "specified, taking the first one"; + if (descriptions->size() > 1) { + auto iter = descriptions->begin(); + auto desc = std::move(iter->second); + if (!desc.has_value()) { + return StreamStringError("Expected description."); + } + if (!desc.value().dag.has_value()) { + return StreamStringError("Expected dag in description."); + } + iter++; + while (iter != descriptions->end()) { + if (!iter->second.has_value()) { + return StreamStringError("Expected description."); + } + if (!iter->second.value().dag.has_value()) { + return StreamStringError("Expected dag in description."); + } + desc->dag.value()->concat(*iter->second.value().dag.value()); + iter++; + } + return std::move(desc); } return std::move(descriptions->begin()->second); } @@ -199,7 +214,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { res.fheContext.emplace( mlir::concretelang::V0FHEContext{constraint, v0Params}); - CompilationFeedback feedback; + ProgramCompilationFeedback feedback; res.feedback.emplace(feedback); return llvm::Error::success(); @@ -213,7 +228,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { if (!descr.get().has_value()) { return llvm::Error::success(); } - CompilationFeedback feedback; + ProgramCompilationFeedback feedback; // Make sure to use the gpu constraint of the optimizer if we use gpu // backend. compilerOptions.optimizerConfig.use_gpu_constraints = @@ -322,9 +337,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, // on the `FHE` dialect. if ((this->generateProgramInfo || target == Target::LIBRARY) && !options.encodings) { - auto funcName = options.mainFuncName.value_or("main"); auto encodingInfosOrErr = - mlir::concretelang::encodings::getCircuitEncodings(funcName, module); + mlir::concretelang::encodings::getProgramEncoding(module); if (!encodingInfosOrErr) { return encodingInfosOrErr.takeError(); } @@ -363,7 +377,7 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, chunkedMode.asBuilder().setWidth(options.chunkWidth); maybeChunkInfo = chunkedMode; } - mlir::concretelang::encodings::setCircuitEncodingModes( + mlir::concretelang::encodings::setProgramEncodingModes( *options.encodings, maybeChunkInfo, res.fheContext); } @@ -457,41 +471,29 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, // Generate client parameters if requested if (this->generateProgramInfo) { - if (!options.mainFuncName.has_value()) { - return StreamStringError( - "Generation of client parameters requested, but no function name " - "specified"); - } if (!res.fheContext.has_value()) { return StreamStringError( - "Cannot generate client parameters, the fhe context is empty for " + - options.mainFuncName.value()); + "Cannot generate client parameters, the fhe context is empty"); } } // Generate program info if requested if (this->generateProgramInfo || target == Target::LIBRARY) { - auto funcName = options.mainFuncName.value_or("main"); if (!res.fheContext.has_value()) { // Some tests involve call a to non encrypted functions auto programInfo = Message(); programInfo.asBuilder().initCircuits(1); - programInfo.asBuilder().getCircuits()[0].setName(std::string(funcName)); + programInfo.asBuilder().getCircuits()[0].setName(std::string("main")); res.programInfo = programInfo; } else { auto programInfoOrErr = mlir::concretelang::createProgramInfoFromTfheDialect( - module, funcName, options.optimizerConfig.security, + module, options.optimizerConfig.security, options.encodings.value(), options.compressEvaluationKeys); if (!programInfoOrErr) return programInfoOrErr.takeError(); res.programInfo = std::move(*programInfoOrErr); - // If more than one circuit, feedback can not be generated for now .. - if (res.programInfo->asReader().getCircuits().size() != 1) { - return StreamStringError( - "Cannot generate feedback for program with more than one circuit."); - } res.feedback->fillFromProgramInfo(*res.programInfo); } } diff --git a/compilers/concrete-compiler/compiler/lib/Support/Encodings.cpp b/compilers/concrete-compiler/compiler/lib/Support/Encodings.cpp index d80e733ba..5de7cc7ce 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Encodings.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Encodings.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace FHE = mlir::concretelang::FHE; using concretelang::protocol::Message; @@ -67,20 +68,13 @@ encodingFromType(mlir::Type ty) { } llvm::Expected> -getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module) { - // Find the input function - auto rangeOps = module.getOps(); - auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) { - return op.getName() == functionName; - }); - if (funcOp == rangeOps.end()) { - return StreamStringError("Function not found, name='") - << functionName << "', cannot get circuit encodings"; - } - auto funcType = (*funcOp).getFunctionType(); +getCircuitEncodings(mlir::func::FuncOp funcOp) { + + auto funcType = funcOp.getFunctionType(); // Retrieve input/output encodings auto circuitEncodings = Message(); + circuitEncodings.asBuilder().setName(funcOp.getSymName().str()); auto inputsBuilder = circuitEncodings.asBuilder().initInputs(funcType.getNumInputs()); for (size_t i = 0; i < funcType.getNumInputs(); i++) { @@ -105,8 +99,32 @@ getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module) { return std::move(circuitEncodings); } +llvm::Expected> +getProgramEncoding(mlir::ModuleOp module) { + + auto funcs = module.getOps(); + auto circuitEncodings = + std::vector>(); + for (auto func : funcs) { + auto encodingInfosOrErr = getCircuitEncodings(func); + if (!encodingInfosOrErr) { + return encodingInfosOrErr.takeError(); + } + circuitEncodings.push_back(*encodingInfosOrErr); + } + + auto programEncoding = Message(); + auto circuitBuilder = + programEncoding.asBuilder().initCircuits(circuitEncodings.size()); + for (size_t i = 0; i < circuitEncodings.size(); i++) { + circuitBuilder.setWithCaveats(i, circuitEncodings[i].asReader()); + } + + return std::move(programEncoding); +} + void setCircuitEncodingModes( - Message &info, + concreteprotocol::CircuitEncodingInfo::Builder info, std::optional< Message> maybeChunk, @@ -164,13 +182,25 @@ void setCircuitEncodingModes( // Got nothing particular. Setting encoding mode to native. integerEncodingBuilder.getMode().initNative(); }; - for (auto encInfoBuilder : info.asBuilder().getInputs()) { + for (auto encInfoBuilder : info.getInputs()) { setMode(encInfoBuilder); } - for (auto encInfoBuilder : info.asBuilder().getOutputs()) { + for (auto encInfoBuilder : info.getOutputs()) { setMode(encInfoBuilder); } } + +void setProgramEncodingModes( + Message &info, + std::optional< + Message> + maybeChunk, + std::optional maybeFheContext) { + for (auto circuitInfo : info.asBuilder().getCircuits()) { + setCircuitEncodingModes(circuitInfo, maybeChunk, maybeFheContext); + } +} + } // namespace encodings } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index adb141f97..4dcd52dd1 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -5,6 +5,7 @@ #include "llvm/Support/TargetSelect.h" +#include "concretelang/Support/CompilationFeedback.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" @@ -345,7 +346,7 @@ normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - CompilationFeedback &feedback) { + ProgramCompilationFeedback &feedback) { mlir::PassManager pm(&context); pipelinePrinting("TFHEStatistics", pm, context); @@ -371,7 +372,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - CompilationFeedback &feedback) { + ProgramCompilationFeedback &feedback) { mlir::PassManager pm(&context); pipelinePrinting("Computing Memory Usage", pm, context); diff --git a/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp b/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp index a9441dbe5..a42742af7 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp @@ -296,32 +296,22 @@ extractKeysetInfo(TFHE::TFHECircuitKeys circuitKeys, } llvm::Expected> -extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName, - Message &encodings, +extractCircuitInfo(mlir::func::FuncOp funcOp, + concreteprotocol::CircuitEncodingInfo::Reader encodings, concrete::SecurityCurve curve) { auto output = Message(); - // Check that the specified function can be found - auto rangeOps = module.getOps(); - auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) { - return op.getName() == functionName; - }); - if (funcOp == rangeOps.end()) { - return StreamStringError( - "cannot find the function for generate client parameters: ") - << functionName; - } // Create input and output circuit gate parameters - auto funcType = (*funcOp).getFunctionType(); + auto funcType = funcOp.getFunctionType(); - output.asBuilder().setName(functionName.str()); + output.asBuilder().setName(encodings.getName().cStr()); output.asBuilder().initInputs(funcType.getNumInputs()); output.asBuilder().initOutputs(funcType.getNumResults()); for (unsigned int i = 0; i < funcType.getNumInputs(); i++) { auto ty = funcType.getInput(i); - auto encoding = encodings.asReader().getInputs()[i]; + auto encoding = encodings.getInputs()[i]; auto maybeGate = generateGate(ty, encoding, curve); if (!maybeGate) { return maybeGate.takeError(); @@ -330,7 +320,7 @@ extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName, } for (unsigned int i = 0; i < funcType.getNumResults(); i++) { auto ty = funcType.getResult(i); - auto encoding = encodings.asReader().getOutputs()[i]; + auto encoding = encodings.getOutputs()[i]; auto maybeGate = generateGate(ty, encoding, curve); if (!maybeGate) { return maybeGate.takeError(); @@ -341,10 +331,43 @@ extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName, return output; } +llvm::Expected> extractProgramInfo( + mlir::ModuleOp module, + const Message &encodings, + concrete::SecurityCurve curve) { + + auto output = Message(); + auto circuitsCount = encodings.asReader().getCircuits().size(); + auto circuitsBuilder = output.asBuilder().initCircuits(circuitsCount); + auto rangeOps = module.getOps(); + + for (size_t i = 0; i < circuitsCount; i++) { + auto circuitEncoding = encodings.asReader().getCircuits()[i]; + auto functionName = circuitEncoding.getName(); + auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) { + return op.getName() == functionName.cStr(); + }); + if (funcOp == rangeOps.end()) { + return StreamStringError("cannot find the following function to generate " + "program info: ") + << functionName.cStr(); + } + + auto maybeCircuitInfo = extractCircuitInfo(*funcOp, circuitEncoding, curve); + if (!maybeCircuitInfo) { + return maybeCircuitInfo.takeError(); + } + + circuitsBuilder.setWithCaveats(i, (*maybeCircuitInfo).asReader()); + } + + return output; +} + llvm::Expected> createProgramInfoFromTfheDialect( - mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity, - Message &encodings, + mlir::ModuleOp module, int bitsOfSecurity, + const Message &encodings, bool compressEvaluationKeys) { // Check that security curves exist @@ -354,23 +377,20 @@ createProgramInfoFromTfheDialect( << bitsOfSecurity << "bits"; } - // Create the output Program Info. - auto output = Message(); + // We generate the circuit infos from the module. + auto maybeProgramInfo = extractProgramInfo(module, encodings, *curve); + if (!maybeProgramInfo) { + return maybeProgramInfo.takeError(); + } + + // Extract the output Program Info. + Message output = *maybeProgramInfo; // We extract the keys of the circuit auto keysetInfo = extractKeysetInfo(TFHE::extractCircuitKeys(module), *curve, compressEvaluationKeys); output.asBuilder().setKeyset(keysetInfo.asReader()); - // We generate the gates for the inputs aud outputs - auto maybeCircuitInfo = - extractCircuitInfo(module, functionName, encodings, *curve); - if (!maybeCircuitInfo) { - return maybeCircuitInfo.takeError(); - } - output.asBuilder().initCircuits(1).setWithCaveats( - 0, maybeCircuitInfo->asReader()); - return output; } diff --git a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp index 4630233da..c56b7bd3a 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp @@ -268,7 +268,7 @@ optimizer::Solution convertSolution(optimizer::CircuitSolution sol) { /// Fill the compilation `feedback` from a `solution` returned by the optmizer. template -void fillFeedback(Solution solution, CompilationFeedback &feedback) { +void fillFeedback(Solution solution, ProgramCompilationFeedback &feedback) { feedback.complexity = solution.complexity; feedback.pError = solution.p_error; feedback.globalPError = @@ -314,7 +314,7 @@ llvm::Error checkPErrorSolution(Solution solution, optimizer::Config config) { /// optimizer::Solution, and fill the `feedback`. template llvm::Expected -toCompilerSolution(Solution solution, CompilationFeedback &feedback, +toCompilerSolution(Solution solution, ProgramCompilationFeedback &feedback, optimizer::Config config) { // display(descr, config, sol, naive_user, duration); if (auto err = checkPErrorSolution(solution, config); err) { @@ -334,9 +334,9 @@ optimizer::Solution emptySolution() { return solution; } -llvm::Expected getSolution(optimizer::Description &descr, - CompilationFeedback &feedback, - optimizer::Config config) { +llvm::Expected +getSolution(optimizer::Description &descr, ProgramCompilationFeedback &feedback, + optimizer::Config config) { namespace chrono = std::chrono; // auto start = chrono::high_resolution_clock::now(); auto naive_user = diff --git a/compilers/concrete-compiler/compiler/src/main.cpp b/compilers/concrete-compiler/compiler/src/main.cpp index 2360d66e5..8bc5b538e 100644 --- a/compilers/concrete-compiler/compiler/src/main.cpp +++ b/compilers/concrete-compiler/compiler/src/main.cpp @@ -223,11 +223,6 @@ llvm::cl::opt dataflowParallelize( llvm::cl::desc("Generate the program as a dataflow graph"), llvm::cl::init(false)); -llvm::cl::opt - funcName("funcname", - llvm::cl::desc("Name of the function to compile, default 'main'"), - llvm::cl::init("")); - llvm::cl::opt chunkIntegers("chunk-integers", llvm::cl::desc("Whether to decompose integer into chunks or " @@ -379,12 +374,17 @@ llvm::cl::list largeIntegerCircuitBootstrap( "(experimental) [level, baseLog]"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); -llvm::cl::opt circuitEncodings( - "circuit-encodings", - llvm::cl::desc("Specify the input and output encodings of the circuit, " +llvm::cl::opt programEncoding( + "program-encoding", + llvm::cl::desc("Specify the encodings to use for the program, " "using the JSON representation."), llvm::cl::init(std::string{})); +llvm::cl::opt skipProgramInfo( + "skip-program-info", + llvm::cl::desc("Skip generating the program info artefacts."), + llvm::cl::init(false)); + } // namespace cmdline namespace llvm { @@ -424,6 +424,7 @@ cmdlineCompilationOptions() { options.chunkIntegers = cmdline::chunkIntegers; options.chunkSize = cmdline::chunkSize; options.chunkWidth = cmdline::chunkWidth; + options.skipProgramInfo = cmdline::skipProgramInfo; if (!cmdline::v0Constraint.empty()) { if (cmdline::v0Constraint.size() != 2) { @@ -435,10 +436,6 @@ cmdlineCompilationOptions() { cmdline::v0Constraint[1], cmdline::v0Constraint[0]}; } - if (!cmdline::funcName.empty()) { - options.mainFuncName = cmdline::funcName; - } - // Convert tile sizes to `Optional` if (!cmdline::fhelinalgTileSizes.empty()) options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes); @@ -512,12 +509,12 @@ cmdlineCompilationOptions() { llvm::inconvertibleErrorCode()); } - if (!cmdline::circuitEncodings.empty()) { - auto jsonString = cmdline::circuitEncodings.getValue(); - auto encodings = Message(); + if (!cmdline::programEncoding.empty()) { + auto jsonString = cmdline::programEncoding.getValue(); + auto encodings = Message(); if (encodings.readJsonFromString(jsonString).has_failure()) { return llvm::make_error( - "Failed to parse the --circuit-encodings option", + "Failed to parse the --program-encoding option", llvm::inconvertibleErrorCode()); } options.encodings = encodings; @@ -556,10 +553,9 @@ mlir::LogicalResult processInputBuffer( std::shared_ptr ccx = mlir::concretelang::CompilationContext::createShared(); - std::string funcName = options.mainFuncName.value_or(""); - mlir::concretelang::CompilerEngine ce{ccx}; ce.setCompilationOptions(std::move(options)); + ce.setGenerateProgramInfo(!options.skipProgramInfo); if (cmdline::passes.size() != 0) { ce.setEnablePass([](mlir::Pass *pass) { diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_785.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_785.mlir index 5386eb9b0..2e3764aff 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_785.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_785.mlir @@ -1,15 +1,15 @@ -// RUN: concretecompiler --action=dump-llvm-ir %s +// RUN: concretecompiler --action=dump-llvm-ir --optimizer-strategy=V0 --skip-program-info %s // Just ensure that compile // https://github.com/zama-ai/concrete-compiler-internal/issues/785 -func.func @main(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<1x!FHE.eint<15>> { - %1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15> - %6 = tensor.from_elements %1 : tensor<1x!FHE.eint<15>> // ERROR HERE line 4 - return %6 : tensor<1x!FHE.eint<15>> +func.func @main(%arg0: !FHE.eint<5>, %cst: tensor<32xi64>) -> tensor<1x!FHE.eint<5>> { + %1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<5> + %6 = tensor.from_elements %1 : tensor<1x!FHE.eint<5>> // ERROR HERE line 4 + return %6 : tensor<1x!FHE.eint<5>> } // Ensures that tensors of multiple elements can be constructed as well. -func.func @main2(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<2x!FHE.eint<15>> { - %1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15> - %6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<15>> // ERROR HERE line 4 - return %6 : tensor<2x!FHE.eint<15>> +func.func @main2(%arg0: !FHE.eint<5>, %cst: tensor<32xi64>) -> tensor<2x!FHE.eint<5>> { + %1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<5> + %6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<5>> // ERROR HERE line 4 + return %6 : tensor<2x!FHE.eint<5>> } diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_858.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_858.mlir index b75f3da50..bcb2378bf 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_858.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_858.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s +// RUN: concretecompiler --action=dump-tfhe --optimizer-strategy=V0 --force-encoding crt --skip-program-info %s func.func @main(%arg0: tensor<32x!FHE.eint<8>>) -> tensor<16x!FHE.eint<8>>{ %0 = tensor.extract_slice %arg0[16] [16] [1] : tensor<32x!FHE.eint<8>> to tensor<16x!FHE.eint<8>> return %0 : tensor<16x!FHE.eint<8>> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_890.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_890.mlir deleted file mode 100644 index 8690450bd..000000000 --- a/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_890.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s -func.func @main(%2: tensor<1x1x!FHE.eint<16>>) -> tensor<1x1x1x!FHE.eint<16>> { - %3 = tensor.expand_shape %2 [[0], [1, 2]] : tensor<1x1x!FHE.eint<16>> into tensor<1x1x1x!FHE.eint<16>> - return %3 : tensor<1x1x1x!FHE.eint<16>> -} - -func.func @main2(%2: tensor<1x1x1x!FHE.eint<16>>) -> tensor<1x1x!FHE.eint<16>> { - %3 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<1x1x1x!FHE.eint<16>> into tensor<1x1x!FHE.eint<16>> - return %3 : tensor<1x1x!FHE.eint<16>> -} - - diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir index a842601f9..c8e77b9fa 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/ConcreteToLLVM/gpu_ops.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-llvm-dialect --emit-gpu-ops %s 2>&1| FileCheck %s +// RUN: concretecompiler --action=dump-llvm-dialect --emit-gpu-ops --skip-program-info %s 2>&1| FileCheck %s //CHECK: llvm.call @memref_keyswitch_lwe_cuda_u64 //CHECK: llvm.call @memref_bootstrap_lwe_cuda_u64 diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir index ccfafb834..7e3caab98 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=V0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s +// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=V0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe>) -> !TFHE.glwe> { //CHECK-NEXT: %cst = arith.constant dense<0> : tensor<1024xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir index ed4ead296..9b83303fa 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @add_glwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> func.func @add_glwe(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>) -> !TFHE.glwe> { diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir index 467edcc9b..da24a1298 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir index fb48740a6..bd6399b3f 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<601xi64>) -> tensor<1025xi64> { //CHECK-NEXT: %cst = arith.constant dense<"0x00100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F0000000000000080000000000000008100000000000000820000000000000083000000000000008400000000000000850000000000000086000000000000008700000000000000880000000000000089000000000000008A000000000000008B000000000000008C000000000000008D000000000000008E000000000000008F0000000000000090000000000000009100000000000000920000000000000093000000000000009400000000000000950000000000000096000000000000009700000000000000980000000000000099000000000000009A000000000000009B000000000000009C000000000000009D000000000000009E000000000000009F00000000000000A000000000000000A100000000000000A200000000000000A300000000000000A400000000000000A500000000000000A600000000000000A700000000000000A800000000000000A900000000000000AA00000000000000AB00000000000000AC00000000000000AD00000000000000AE00000000000000AF00000000000000B000000000000000B100000000000000B200000000000000B300000000000000B400000000000000B500000000000000B600000000000000B700000000000000B800000000000000B900000000000000BA00000000000000BB00000000000000BC00000000000000BD00000000000000BE00000000000000BF00000000000000C000000000000000C100000000000000C200000000000000C300000000000000C400000000000000C500000000000000C600000000000000C700000000000000C800000000000000C900000000000000CA00000000000000CB00000000000000CC00000000000000CD00000000000000CE00000000000000CF00000000000000D000000000000000D100000000000000D200000000000000D300000000000000D400000000000000D500000000000000D600000000000000D700000000000000D800000000000000D900000000000000DA00000000000000DB00000000000000DC00000000000000DD00000000000000DE00000000000000DF00000000000000E000000000000000E100000000000000E200000000000000E300000000000000E400000000000000E500000000000000E600000000000000E700000000000000E800000000000000E900000000000000EA00000000000000EB00000000000000EC00000000000000ED00000000000000EE00000000000000EF00000000000000F000000000000000F100000000000000F200000000000000F300000000000000F400000000000000F500000000000000F600000000000000F700000000000000F800000000000000F900000000000000FA00000000000000FB00000000000000FC00000000000000FD00000000000000FE00000000000000FF0000000000000100000000000000010100000000000001020000000000000103000000000000010400000000000001050000000000000106000000000000010700000000000001080000000000000109000000000000010A000000000000010B000000000000010C000000000000010D000000000000010E000000000000010F0000000000000110000000000000011100000000000001120000000000000113000000000000011400000000000001150000000000000116000000000000011700000000000001180000000000000119000000000000011A000000000000011B000000000000011C000000000000011D000000000000011E000000000000011F0000000000000120000000000000012100000000000001220000000000000123000000000000012400000000000001250000000000000126000000000000012700000000000001280000000000000129000000000000012A000000000000012B000000000000012C000000000000012D000000000000012E000000000000012F0000000000000130000000000000013100000000000001320000000000000133000000000000013400000000000001350000000000000136000000000000013700000000000001380000000000000139000000000000013A000000000000013B000000000000013C000000000000013D000000000000013E000000000000013F0000000000000140000000000000014100000000000001420000000000000143000000000000014400000000000001450000000000000146000000000000014700000000000001480000000000000149000000000000014A000000000000014B000000000000014C000000000000014D000000000000014E000000000000014F0000000000000150000000000000015100000000000001520000000000000153000000000000015400000000000001550000000000000156000000000000015700000000000001580000000000000159000000000000015A000000000000015B000000000000015C000000000000015D000000000000015E000000000000015F0000000000000160000000000000016100000000000001620000000000000163000000000000016400000000000001650000000000000166000000000000016700000000000001680000000000000169000000000000016A000000000000016B000000000000016C000000000000016D000000000000016E000000000000016F0000000000000170000000000000017100000000000001720000000000000173000000000000017400000000000001750000000000000176000000000000017700000000000001780000000000000179000000000000017A000000000000017B000000000000017C000000000000017D000000000000017E000000000000017F0000000000000180000000000000018100000000000001820000000000000183000000000000018400000000000001850000000000000186000000000000018700000000000001880000000000000189000000000000018A000000000000018B000000000000018C000000000000018D000000000000018E000000000000018F0000000000000190000000000000019100000000000001920000000000000193000000000000019400000000000001950000000000000196000000000000019700000000000001980000000000000199000000000000019A000000000000019B000000000000019C000000000000019D000000000000019E000000000000019F00000000000001A000000000000001A100000000000001A200000000000001A300000000000001A400000000000001A500000000000001A600000000000001A700000000000001A800000000000001A900000000000001AA00000000000001AB00000000000001AC00000000000001AD00000000000001AE00000000000001AF00000000000001B000000000000001B100000000000001B200000000000001B300000000000001B400000000000001B500000000000001B600000000000001B700000000000001B800000000000001B900000000000001BA00000000000001BB00000000000001BC00000000000001BD00000000000001BE00000000000001BF00000000000001C000000000000001C100000000000001C200000000000001C300000000000001C400000000000001C500000000000001C600000000000001C700000000000001C800000000000001C900000000000001CA00000000000001CB00000000000001CC00000000000001CD00000000000001CE00000000000001CF00000000000001D000000000000001D100000000000001D200000000000001D300000000000001D400000000000001D500000000000001D600000000000001D700000000000001D800000000000001D900000000000001DA00000000000001DB00000000000001DC00000000000001DD00000000000001DE00000000000001DF00000000000001E000000000000001E100000000000001E200000000000001E300000000000001E400000000000001E500000000000001E600000000000001E700000000000001E800000000000001E900000000000001EA00000000000001EB00000000000001EC00000000000001ED00000000000001EE00000000000001EF00000000000001F000000000000001F100000000000001F200000000000001F300000000000001F400000000000001F500000000000001F600000000000001F700000000000001F800000000000001F900000000000001FA00000000000001FB00000000000001FC00000000000001FD00000000000001FE00000000000001FF0000000000000200000000000000020100000000000002020000000000000203000000000000020400000000000002050000000000000206000000000000020700000000000002080000000000000209000000000000020A000000000000020B000000000000020C000000000000020D000000000000020E000000000000020F0000000000000210000000000000021100000000000002120000000000000213000000000000021400000000000002150000000000000216000000000000021700000000000002180000000000000219000000000000021A000000000000021B000000000000021C000000000000021D000000000000021E000000000000021F0000000000000220000000000000022100000000000002220000000000000223000000000000022400000000000002250000000000000226000000000000022700000000000002280000000000000229000000000000022A000000000000022B000000000000022C000000000000022D000000000000022E000000000000022F0000000000000230000000000000023100000000000002320000000000000233000000000000023400000000000002350000000000000236000000000000023700000000000002380000000000000239000000000000023A000000000000023B000000000000023C000000000000023D000000000000023E000000000000023F0000000000000240000000000000024100000000000002420000000000000243000000000000024400000000000002450000000000000246000000000000024700000000000002480000000000000249000000000000024A000000000000024B000000000000024C000000000000024D000000000000024E000000000000024F0000000000000250000000000000025100000000000002520000000000000253000000000000025400000000000002550000000000000256000000000000025700000000000002580000000000000259000000000000025A000000000000025B000000000000025C000000000000025D000000000000025E000000000000025F0000000000000260000000000000026100000000000002620000000000000263000000000000026400000000000002650000000000000266000000000000026700000000000002680000000000000269000000000000026A000000000000026B000000000000026C000000000000026D000000000000026E000000000000026F0000000000000270000000000000027100000000000002720000000000000273000000000000027400000000000002750000000000000276000000000000027700000000000002780000000000000279000000000000027A000000000000027B000000000000027C000000000000027D000000000000027E000000000000027F0000000000000280000000000000028100000000000002820000000000000283000000000000028400000000000002850000000000000286000000000000028700000000000002880000000000000289000000000000028A000000000000028B000000000000028C000000000000028D000000000000028E000000000000028F0000000000000290000000000000029100000000000002920000000000000293000000000000029400000000000002950000000000000296000000000000029700000000000002980000000000000299000000000000029A000000000000029B000000000000029C000000000000029D000000000000029E000000000000029F00000000000002A000000000000002A100000000000002A200000000000002A300000000000002A400000000000002A500000000000002A600000000000002A700000000000002A800000000000002A900000000000002AA00000000000002AB00000000000002AC00000000000002AD00000000000002AE00000000000002AF00000000000002B000000000000002B100000000000002B200000000000002B300000000000002B400000000000002B500000000000002B600000000000002B700000000000002B800000000000002B900000000000002BA00000000000002BB00000000000002BC00000000000002BD00000000000002BE00000000000002BF00000000000002C000000000000002C100000000000002C200000000000002C300000000000002C400000000000002C500000000000002C600000000000002C700000000000002C800000000000002C900000000000002CA00000000000002CB00000000000002CC00000000000002CD00000000000002CE00000000000002CF00000000000002D000000000000002D100000000000002D200000000000002D300000000000002D400000000000002D500000000000002D600000000000002D700000000000002D800000000000002D900000000000002DA00000000000002DB00000000000002DC00000000000002DD00000000000002DE00000000000002DF00000000000002E000000000000002E100000000000002E200000000000002E300000000000002E400000000000002E500000000000002E600000000000002E700000000000002E800000000000002E900000000000002EA00000000000002EB00000000000002EC00000000000002ED00000000000002EE00000000000002EF00000000000002F000000000000002F100000000000002F200000000000002F300000000000002F400000000000002F500000000000002F600000000000002F700000000000002F800000000000002F900000000000002FA00000000000002FB00000000000002FC00000000000002FD00000000000002FE00000000000002FF0000000000000300000000000000030100000000000003020000000000000303000000000000030400000000000003050000000000000306000000000000030700000000000003080000000000000309000000000000030A000000000000030B000000000000030C000000000000030D000000000000030E000000000000030F0000000000000310000000000000031100000000000003120000000000000313000000000000031400000000000003150000000000000316000000000000031700000000000003180000000000000319000000000000031A000000000000031B000000000000031C000000000000031D000000000000031E000000000000031F0000000000000320000000000000032100000000000003220000000000000323000000000000032400000000000003250000000000000326000000000000032700000000000003280000000000000329000000000000032A000000000000032B000000000000032C000000000000032D000000000000032E000000000000032F0000000000000330000000000000033100000000000003320000000000000333000000000000033400000000000003350000000000000336000000000000033700000000000003380000000000000339000000000000033A000000000000033B000000000000033C000000000000033D000000000000033E000000000000033F0000000000000340000000000000034100000000000003420000000000000343000000000000034400000000000003450000000000000346000000000000034700000000000003480000000000000349000000000000034A000000000000034B000000000000034C000000000000034D000000000000034E000000000000034F0000000000000350000000000000035100000000000003520000000000000353000000000000035400000000000003550000000000000356000000000000035700000000000003580000000000000359000000000000035A000000000000035B000000000000035C000000000000035D000000000000035E000000000000035F0000000000000360000000000000036100000000000003620000000000000363000000000000036400000000000003650000000000000366000000000000036700000000000003680000000000000369000000000000036A000000000000036B000000000000036C000000000000036D000000000000036E000000000000036F0000000000000370000000000000037100000000000003720000000000000373000000000000037400000000000003750000000000000376000000000000037700000000000003780000000000000379000000000000037A000000000000037B000000000000037C000000000000037D000000000000037E000000000000037F0000000000000380000000000000038100000000000003820000000000000383000000000000038400000000000003850000000000000386000000000000038700000000000003880000000000000389000000000000038A000000000000038B000000000000038C000000000000038D000000000000038E000000000000038F0000000000000390000000000000039100000000000003920000000000000393000000000000039400000000000003950000000000000396000000000000039700000000000003980000000000000399000000000000039A000000000000039B000000000000039C000000000000039D000000000000039E000000000000039F00000000000003A000000000000003A100000000000003A200000000000003A300000000000003A400000000000003A500000000000003A600000000000003A700000000000003A800000000000003A900000000000003AA00000000000003AB00000000000003AC00000000000003AD00000000000003AE00000000000003AF00000000000003B000000000000003B100000000000003B200000000000003B300000000000003B400000000000003B500000000000003B600000000000003B700000000000003B800000000000003B900000000000003BA00000000000003BB00000000000003BC00000000000003BD00000000000003BE00000000000003BF00000000000003C000000000000003C100000000000003C200000000000003C300000000000003C400000000000003C500000000000003C600000000000003C700000000000003C800000000000003C900000000000003CA00000000000003CB00000000000003CC00000000000003CD00000000000003CE00000000000003CF00000000000003D000000000000003D100000000000003D200000000000003D300000000000003D400000000000003D500000000000003D600000000000003D700000000000003D800000000000003D900000000000003DA00000000000003DB00000000000003DC00000000000003DD00000000000003DE00000000000003DF00000000000003E000000000000003E100000000000003E200000000000003E300000000000003E400000000000003E500000000000003E600000000000003E700000000000003E800000000000003E900000000000003EA00000000000003EB00000000000003EC00000000000003ED00000000000003EE00000000000003EF00000000000003F000000000000003F100000000000003F200000000000003F300000000000003F400000000000003F500000000000003F600000000000003F700000000000003F800000000000003F900000000000003FA00000000000003FB00000000000003FC00000000000003FD00000000000003FE00000000000003FF00000000000004000000000000000"> : tensor<1024xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir index fd369d907..5cca3ccbe 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> { // CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_lut_for_woppbs.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_lut_for_woppbs.mlir index d647061e0..cfbd458a5 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_lut_for_woppbs.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_lut_for_woppbs.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<5x8192xi64> { // CHECK-NEXT: %0 = "Concrete.encode_lut_for_crt_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir index 694634319..1ee04a327 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_plaintext_with_crt.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> { // CHECK-NEXT: %0 = "Concrete.encode_plaintext_with_crt_tensor"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir index 92a462b7f..47529f92a 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/keyswitch.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @keyswitch_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<568xi64> { // CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, kskIndex = -1 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir index 83a782474..ac9e5607e 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir index 40b6d5884..52e6e013b 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/neg_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @neg_glwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> func.func @neg_glwe(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir index 36f3a66d5..64ae8cef4 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i64 = arith.constant 1 : i64 diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_exapand_collapse_shape.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_exapand_collapse_shape.mlir index e608ae78e..8a77262cc 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_exapand_collapse_shape.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_exapand_collapse_shape.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --split-input-file --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --split-input-file --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @tensor_collapse_shape(%[[A0:.*]]: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> { //CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : tensor<2x3x4x5x6x1025xi64> into tensor<720x1025xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir index 7a8608af7..60f9620a5 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_from_elements.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<6x2049xi64> { // CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<6x2049xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_identity.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_identity.mlir index 43d5a2f75..eebde108e 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_identity.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHEToConcrete/tensor_identity.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> { // CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir index 16aa1e897..a323cc332 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s +// RUN: concretecompiler --action=dump-fhe --optimizer-strategy=V0 --skip-program-info %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir index db55f3fa7..23d8147d0 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s +// RUN: concretecompiler --action=dump-fhe --optimizer-strategy=V0 --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @add_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { // CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/no_optimization.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/no_optimization.mlir index a4972fb7a..d80c280d4 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/no_optimization.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/no_optimization.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-optimization --optimize-tfhe=false --action=dump-tfhe %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-optimization --optimize-tfhe=false --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s //CHECK: func.func @mul_cleartext_glwe_ciphertext_0(%[[A0:.*]]: !TFHE.glwe>) -> !TFHE.glwe> { //CHECK: %c0_i64 = arith.constant 0 : i64 diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir index a7ed3f83d..ef4edbfbb 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/TFHE/optimization.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe>, %arg1: i64) -> !TFHE.glwe> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Transforms/batching.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Transforms/batching.mlir index a4c7f6af9..814dc7bde 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Transforms/batching.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Transforms/batching.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --split-input-file --action=dump-batched-tfhe --batch-tfhe-ops %s 2>&1| FileCheck %s +// RUN: concretecompiler --split-input-file --action=dump-batched-tfhe --batch-tfhe-ops --skip-program-info %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @batch_continuous_slice_keyswitch // CHECK: (%arg0: tensor<2x3x4x!TFHE.glwe>>) -> tensor<2x3x4x!TFHE.glwe>> { diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/TypeInference/inference.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/TypeInference/inference.mlir index 732c4c304..007bd6e52 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/TypeInference/inference.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/TypeInference/inference.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --split-input-file --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi %s 2>&1| FileCheck %s +// RUN: concretecompiler --split-input-file --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi --skip-program-info %s 2>&1| FileCheck %s // CHECK: func.func @funconly_fwd(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { // CHECK-NEXT: return %arg0 : !TFHE.glwe> diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp b/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp index 4371c0765..897f3e8eb 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp @@ -1,6 +1,6 @@ #include "../end_to_end_tests/end_to_end_test.h" #include "concretelang/Common/Compat.h" -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include #include @@ -23,7 +23,7 @@ using namespace concretelang::testlib; /// Benchmark time of the compilation static void BM_Compile(benchmark::State &state, EndToEndDesc description, mlir::concretelang::CompilationOptions options) { - TestCircuit tc(options); + TestProgram tc(options); for (auto _ : state) { assert(tc.compile(description.program)); } @@ -32,7 +32,7 @@ static void BM_Compile(benchmark::State &state, EndToEndDesc description, /// Benchmark time of the key generation static void BM_KeyGen(benchmark::State &state, EndToEndDesc description, mlir::concretelang::CompilationOptions options) { - TestCircuit tc(options); + TestProgram tc(options); assert(tc.compile(description.program)); for (auto _ : state) { @@ -44,7 +44,7 @@ static void BM_KeyGen(benchmark::State &state, EndToEndDesc description, static void BM_ExportArguments(benchmark::State &state, EndToEndDesc description, mlir::concretelang::CompilationOptions options) { - TestCircuit tc(options); + TestProgram tc(options); assert(tc.compile(description.program)); assert(tc.generateKeyset()); @@ -68,7 +68,7 @@ static void BM_ExportArguments(benchmark::State &state, /// Benchmark time of the program evaluation static void BM_Evaluate(benchmark::State &state, EndToEndDesc description, mlir::concretelang::CompilationOptions options) { - TestCircuit tc(options); + TestProgram tc(options); assert(tc.compile(description.program)); assert(tc.generateKeyset()); auto clientCircuit = tc.getClientCircuit().value(); @@ -111,7 +111,6 @@ void registerEndToEndBenchmark(std::string suiteName, int num_iterations = 0) { auto optionsName = getOptionsName(options); for (auto description : descriptions) { - options.mainFuncName = "main"; if (description.p_error) { assert(std::isnan(options.optimizerConfig.global_p_error)); options.optimizerConfig.p_error = description.p_error.value(); diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_mlbench.cpp b/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_mlbench.cpp index f473337f3..5c69a6616 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_mlbench.cpp +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_mlbench.cpp @@ -1,5 +1,5 @@ #include "concretelang/Common/Compat.h" -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_fixture/EndToEndFixture.h" #include #define BENCHMARK_HAS_CXX11 diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_aes_short.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_aes_short.cc index 03edf6259..d14a74aa3 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_aes_short.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_aes_short.cc @@ -4,7 +4,7 @@ #include #include -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_jit_test.h" #include "tests_tools/GtestEnvironment.h" std::vector distributed_results; diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc index 84c9416e3..d1391cc2f 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc @@ -5,7 +5,7 @@ #include #include -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_jit_test.h" #include "tests_tools/GtestEnvironment.h" /////////////////////////////////////////////////////////////////////////////// diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_chunked_int.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_chunked_int.cc index c485d8177..3e542ae05 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_chunked_int.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_chunked_int.cc @@ -1,6 +1,6 @@ #include -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_jit_test.h" #include "tests_tools/GtestEnvironment.h" diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc index 840c48225..8b9ff2d45 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc @@ -4,7 +4,7 @@ #include #include -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_jit_test.h" #include "tests_tools/GtestEnvironment.h" diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_lambda.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_lambda.cc index 7f4aed329..0f11405b3 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_lambda.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_lambda.cc @@ -1,6 +1,6 @@ #include -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_jit_test.h" #include "tests_tools/GtestEnvironment.h" #include "tests_tools/assert.h" diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc index 0dcb52b7a..ec13275e4 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc @@ -4,7 +4,7 @@ #include #include -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_jit_test.h" #include "tests_tools/GtestEnvironment.h" @@ -395,10 +395,10 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { } TEST(CompileNotComposable, not_composable_1) { - mlir::concretelang::CompilationOptions options("main"); + mlir::concretelang::CompilationOptions options; options.optimizerConfig.composable = true; options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI; - TestCircuit circuit(options); + TestProgram circuit(options); auto err = circuit.compile(R"XXX( func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { %cst_1 = arith.constant 1 : i4 @@ -411,11 +411,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { } TEST(CompileNotComposable, not_composable_2) { - mlir::concretelang::CompilationOptions options("main"); + mlir::concretelang::CompilationOptions options; options.optimizerConfig.composable = true; options.optimizerConfig.display = true; options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI; - TestCircuit circuit(options); + TestProgram circuit(options); auto err = circuit.compile(R"XXX( func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) { %cst_1 = arith.constant 1 : i4 @@ -430,11 +430,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) { } TEST(CompileComposable, composable_supported_dag_mono) { - mlir::concretelang::CompilationOptions options("main"); + mlir::concretelang::CompilationOptions options; options.optimizerConfig.composable = true; options.optimizerConfig.display = true; options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MONO; - TestCircuit circuit(options); + TestProgram circuit(options); auto err = circuit.compile(R"XXX( func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { %cst_1 = arith.constant 1 : i4 @@ -446,11 +446,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { } TEST(CompileComposable, composable_supported_v0) { - mlir::concretelang::CompilationOptions options("main"); + mlir::concretelang::CompilationOptions options; options.optimizerConfig.composable = true; options.optimizerConfig.display = true; options.optimizerConfig.strategy = mlir::concretelang::optimizer::V0; - TestCircuit circuit(options); + TestProgram circuit(options); auto err = circuit.compile(R"XXX( func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { %cst_1 = arith.constant 1 : i4 @@ -460,3 +460,39 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { )XXX"); assert(err.has_value()); } + +TEST(CompileMultiFunctions, multi_functions_v0) { + mlir::concretelang::CompilationOptions options; + options.optimizerConfig.strategy = mlir::concretelang::optimizer::V0; + TestProgram circuit(options); + auto err = circuit.compile(R"XXX( +func.func @inc(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { + %cst_1 = arith.constant 1 : i4 + %1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> + return %1: !FHE.eint<3> +} +func.func @dec(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { + %cst_1 = arith.constant 1 : i4 + %1 = "FHE.sub_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> + return %1: !FHE.eint<3> +} +)XXX"); + assert(err.has_value()); + assert(circuit.generateKeyset().has_value()); + auto lambda_inc = [&](std::vector args) { + return circuit.call(args, "inc") + .value()[0] + .template getTensor() + .value()[0]; + }; + auto lambda_dec = [&](std::vector args) { + return circuit.call(args, "dec") + .value()[0] + .template getTensor() + .value()[0]; + }; + ASSERT_EQ(lambda_inc({Tensor(1)}), (uint64_t)2); + ASSERT_EQ(lambda_inc({Tensor(4)}), (uint64_t)5); + ASSERT_EQ(lambda_dec({Tensor(1)}), (uint64_t)0); + ASSERT_EQ(lambda_dec({Tensor(4)}), (uint64_t)3); +} diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h index cfbc0bccd..89a2411d4 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h @@ -5,7 +5,7 @@ #include "concretelang/Common/Error.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/V0Parameters.h" -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "cstdlib" #include "end_to_end_test.h" #include "globals.h" @@ -14,7 +14,7 @@ using concretelang::error::Result; using concretelang::error::StringError; -using concretelang::testlib::TestCircuit; +using concretelang::testlib::TestProgram; llvm::StringRef DEFAULT_func = "main"; bool DEFAULT_useDefaultFHEConstraints = false; @@ -30,7 +30,7 @@ bool DEFAULT_composable = false; // Jit-compiles the function specified by `func` from `src` and // returns the corresponding lambda. Any compilation errors are caught // and reult in abnormal termination. -inline Result internalCheckedJit( +inline Result internalCheckedJit( llvm::StringRef src, llvm::StringRef func = DEFAULT_func, bool useDefaultFHEConstraints = DEFAULT_useDefaultFHEConstraints, bool dataflowParallelize = DEFAULT_dataflowParallelize, @@ -42,8 +42,7 @@ inline Result internalCheckedJit( unsigned int chunkWidth = DEFAULT_chunkWidth, bool composable = DEFAULT_composable) { - auto options = - mlir::concretelang::CompilationOptions(std::string(func.data())); + auto options = mlir::concretelang::CompilationOptions(); options.optimizerConfig.global_p_error = global_p_error; options.chunkIntegers = chunkedIntegers; options.chunkSize = chunkSize; @@ -70,10 +69,10 @@ inline Result internalCheckedJit( } std::vector sources = {src.str()}; - TestCircuit testCircuit(options); - OUTCOME_TRYV(testCircuit.compile({src.str()})); - OUTCOME_TRYV(testCircuit.generateKeyset()); - return std::move(testCircuit); + TestProgram testProgram(options); + OUTCOME_TRYV(testProgram.compile({src.str()})); + OUTCOME_TRYV(testProgram.generateKeyset()); + return std::move(testProgram); } // Wrapper around `internalCheckedJit` that causes diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc index 4f3559221..617cfb9a4 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc @@ -7,14 +7,14 @@ #include "concretelang/Common/Values.h" #include "concretelang/Support/CompilationFeedback.h" -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "end_to_end_fixture/EndToEndFixture.h" #include "end_to_end_jit_test.h" #include "tests_tools/GtestEnvironment.h" #include "tests_tools/assert.h" #include "tests_tools/keySetCache.h" -using concretelang::testlib::TestCircuit; +using concretelang::testlib::TestProgram; using concretelang::values::Value; /// @brief EndToEndTest is a template that allows testing for one program for a @@ -35,7 +35,7 @@ public: }; void SetUp() override { - TestCircuit tc(options.compilationOptions); + TestProgram tc(options.compilationOptions); ASSERT_OUTCOME_HAS_VALUE(tc.compile({program})); ASSERT_OUTCOME_HAS_VALUE(tc.generateKeyset()); testCircuit.emplace(std::move(tc)); @@ -122,7 +122,7 @@ private: TestDescription desc; std::optional errorRate; std::optional library; - std::optional testCircuit; + std::optional testCircuit; EndToEndTestOptions options; std::vector args; }; diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h index d15ceea3c..a1b5859f9 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h @@ -135,8 +135,7 @@ parseEndToEndCommandLine(int argc, char **argv) { llvm::cl::ParseCommandLineOptions(argc, argv); // Build compilation options - mlir::concretelang::CompilationOptions compilationOptions("main", - backend.getValue()); + mlir::concretelang::CompilationOptions compilationOptions(backend.getValue()); if (loopParallelize.getValue() != -1) compilationOptions.loopParallelize = loopParallelize.getValue(); diff --git a/compilers/concrete-compiler/compiler/tests/python/test_compilation.py b/compilers/concrete-compiler/compiler/tests/python/test_compilation.py index 81230e6e8..7d3eaf757 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_compilation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_compilation.py @@ -7,7 +7,8 @@ from concrete.compiler import ( LibrarySupport, ClientSupport, CompilationOptions, - CompilationFeedback, + ProgramCompilationFeedback, + CircuitCompilationFeedback, ) @@ -23,31 +24,40 @@ def assert_result(result, expected_result): assert np.all(result == expected_result) -def run(engine, args, compilation_result, keyset_cache): +def run(engine, args, compilation_result, keyset_cache, circuit_name="main"): """Execute engine on the given arguments. Perform required loading, encryption, execution, and decryption.""" # Dev compilation_feedback = engine.load_compilation_feedback(compilation_result) - assert isinstance(compilation_feedback, CompilationFeedback) + assert isinstance(compilation_feedback, ProgramCompilationFeedback) assert isinstance(compilation_feedback.complexity, float) assert isinstance(compilation_feedback.p_error, float) assert isinstance(compilation_feedback.global_p_error, float) assert isinstance(compilation_feedback.total_secret_keys_size, int) assert isinstance(compilation_feedback.total_bootstrap_keys_size, int) - assert isinstance(compilation_feedback.total_inputs_size, int) - assert isinstance(compilation_feedback.total_output_size, int) + assert isinstance(compilation_feedback.circuit_feedbacks, list) + circuit_feedback = next( + filter(lambda x: x.name == circuit_name, compilation_feedback.circuit_feedbacks) + ) + assert isinstance(circuit_feedback, CircuitCompilationFeedback) + assert isinstance(circuit_feedback.total_inputs_size, int) + assert isinstance(circuit_feedback.total_output_size, int) # Client client_parameters = engine.load_client_parameters(compilation_result) key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) + public_arguments = ClientSupport.encrypt_arguments( + client_parameters, key_set, args, circuit_name + ) # Server - server_lambda = engine.load_server_lambda(compilation_result, False) + server_lambda = engine.load_server_lambda(compilation_result, False, circuit_name) evaluation_keys = key_set.get_evaluation_keys() public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys) # Client - result = ClientSupport.decrypt_result(client_parameters, key_set, public_result) + result = ClientSupport.decrypt_result( + client_parameters, key_set, public_result, circuit_name + ) return result @@ -57,11 +67,12 @@ def compile_run_assert( args, expected_result, keyset_cache, - options=CompilationOptions.new("main"), + options=CompilationOptions.new(), + circuit_name="main", ): """Compile run and assert result.""" compilation_result = engine.compile(mlir_input, options) - result = run(engine, args, compilation_result, keyset_cache) + result = run(engine, args, compilation_result, keyset_cache, circuit_name) assert_result(result, expected_result) @@ -231,6 +242,33 @@ def test_lib_compilation_artifacts(): assert not os.path.exists(engine.get_shared_lib_path()) +def test_multi_circuits(keyset_cache): + from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy + + mlir_str = """ + func.func @add(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + func.func @sub(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """ + args = (10, 3) + expected_add_result = 13 + expected_sub_result = 7 + engine = LibrarySupport.new("./py_test_multi_circuits") + options = CompilationOptions.new() + options.set_optimizer_strategy(OptimizerStrategy.V0) + compile_run_assert( + engine, mlir_str, args, expected_add_result, keyset_cache, options, "add" + ) + compile_run_assert( + engine, mlir_str, args, expected_sub_result, keyset_cache, options, "sub" + ) + + def _test_lib_compile_and_run_with_options(keyset_cache, options): mlir_input = """ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { @@ -247,21 +285,21 @@ def _test_lib_compile_and_run_with_options(keyset_cache, options): def test_lib_compile_and_run_p_error(keyset_cache): - options = CompilationOptions.new("main") + options = CompilationOptions.new() options.set_p_error(0.00001) options.set_display_optimizer_choice(True) _test_lib_compile_and_run_with_options(keyset_cache, options) def test_lib_compile_and_run_global_p_error(keyset_cache): - options = CompilationOptions.new("main") + options = CompilationOptions.new() options.set_global_p_error(0.00001) options.set_display_optimizer_choice(True) _test_lib_compile_and_run_with_options(keyset_cache, options) def test_lib_compile_and_run_security_level(keyset_cache): - options = CompilationOptions.new("main") + options = CompilationOptions.new() options.set_security_level(80) options.set_display_optimizer_choice(True) _test_lib_compile_and_run_with_options(keyset_cache, options) @@ -276,7 +314,7 @@ def test_compile_and_run_auto_parallelize( ): artifact_dir = "./py_test_compile_and_run_auto_parallelize" engine = LibrarySupport.new(artifact_dir) - options = CompilationOptions.new("main") + options = CompilationOptions.new() options.set_auto_parallelize(True) compile_run_assert( engine, mlir_input, args, expected_result, keyset_cache, options=options @@ -301,7 +339,7 @@ def test_compile_and_run_auto_parallelize( # if no_parallel: # artifact_dir = "./py_test_compile_dataflow_and_fail_run" # engine = LibrarySupport.new(artifact_dir) -# options = CompilationOptions.new("main") +# options = CompilationOptions.new() # options.set_auto_parallelize(True) # with pytest.raises( # RuntimeError, @@ -334,7 +372,7 @@ def test_compile_and_run_loop_parallelize( ): artifact_dir = "./py_test_compile_and_run_loop_parallelize" engine = LibrarySupport.new(artifact_dir) - options = CompilationOptions.new("main") + options = CompilationOptions.new() options.set_loop_parallelize(True) compile_run_assert( engine, mlir_input, args, expected_result, keyset_cache, options=options @@ -365,29 +403,6 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache): compile_run_assert(engine, mlir_input, args, None, keyset_cache) -@pytest.mark.parametrize( - "mlir_input", - [ - pytest.param( - """ - func.func @test(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - id="not @main", - ), - ], -) -def test_compile_invalid(mlir_input): - artifact_dir = "./py_test_compile_invalid" - engine = LibrarySupport.new(artifact_dir) - with pytest.raises(RuntimeError, match=r"Function not found, name='main'"): - engine.compile(mlir_input) - - def test_crt_decomposition_feedback(): mlir = """ @@ -401,11 +416,27 @@ func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { artifact_dir = "./py_test_crt_decomposition_feedback" engine = LibrarySupport.new(artifact_dir) - compilation_result = engine.compile(mlir, options=CompilationOptions.new("main")) + compilation_result = engine.compile(mlir, options=CompilationOptions.new()) compilation_feedback = engine.load_compilation_feedback(compilation_result) - assert isinstance(compilation_feedback, CompilationFeedback) - assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]] + assert isinstance(compilation_feedback, ProgramCompilationFeedback) + assert isinstance(compilation_feedback.complexity, float) + assert isinstance(compilation_feedback.p_error, float) + assert isinstance(compilation_feedback.global_p_error, float) + assert isinstance(compilation_feedback.total_secret_keys_size, int) + assert isinstance(compilation_feedback.total_bootstrap_keys_size, int) + assert isinstance(compilation_feedback.circuit_feedbacks, list) + assert isinstance( + compilation_feedback.circuit_feedbacks[0], CircuitCompilationFeedback + ) + assert isinstance(compilation_feedback.circuit_feedbacks[0].total_inputs_size, int) + assert isinstance(compilation_feedback.circuit_feedbacks[0].total_output_size, int) + assert isinstance( + compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs, list + ) + assert compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs == [ + [7, 8, 9, 11, 13] + ] @pytest.mark.parametrize( @@ -450,10 +481,11 @@ def test_memory_usage(mlir: str, expected_memory_usage_per_loc: dict): engine = LibrarySupport.new(artifact_dir) compilation_result = engine.compile(mlir) compilation_feedback = engine.load_compilation_feedback(compilation_result) - assert isinstance(compilation_feedback, CompilationFeedback) + assert isinstance(compilation_feedback, ProgramCompilationFeedback) assert ( - expected_memory_usage_per_loc == compilation_feedback.memory_usage_per_location + expected_memory_usage_per_loc + == compilation_feedback.circuit_feedbacks[0].memory_usage_per_location ) shutil.rmtree(artifact_dir) diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index cc16ae85f..6f6cf73ab 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -49,7 +49,7 @@ def compile_run_assert( mlir_input, args_and_shape, expected_result, - options=CompilationOptions.new("main"), + options=CompilationOptions.new(), ): # compile with simulation options.simulation(True) diff --git a/compilers/concrete-compiler/compiler/tests/python/test_statistics.py b/compilers/concrete-compiler/compiler/tests/python/test_statistics.py index faf69d46f..1c5be1e3e 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_statistics.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_statistics.py @@ -32,7 +32,10 @@ module { compilation_result = support.compile(mlir) client_parameters = support.load_client_parameters(compilation_result) - compilation_feedback = support.load_compilation_feedback(compilation_result) + program_compilation_feedback = support.load_compilation_feedback( + compilation_result + ) + compilation_feedback = program_compilation_feedback.circuit("main") pbs_count = compilation_feedback.count( operations={ diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Encodings/Encodings_unit_tests.cpp b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Encodings/Encodings_unit_tests.cpp index 90694b912..c575005b3 100644 --- a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Encodings/Encodings_unit_tests.cpp +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Encodings/Encodings_unit_tests.cpp @@ -12,7 +12,7 @@ #include "concretelang/Common/Error.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Encodings.h" -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "tests_tools/GtestEnvironment.h" #include "tests_tools/assert.h" @@ -23,17 +23,18 @@ testing::Environment *const dfr_env = using namespace concretelang::testlib; namespace encodings = mlir::concretelang::encodings; -Result setupTestCircuit(std::string source, +Result setupTestProgram(std::string source, std::string funcname = FUNCNAME) { std::vector sources = {source}; std::shared_ptr ccx = mlir::concretelang::CompilationContext::createShared(); mlir::concretelang::CompilerEngine ce{ccx}; - mlir::concretelang::CompilationOptions options(funcname); + mlir::concretelang::CompilationOptions options; - options.encodings = Message(); - auto inputs = options.encodings->asBuilder().initInputs(2); - auto outputs = options.encodings->asBuilder().initOutputs(1); + auto circuitEncoding = Message(); + auto inputs = circuitEncoding.asBuilder().initInputs(2); + auto outputs = circuitEncoding.asBuilder().initOutputs(1); + circuitEncoding.asBuilder().setName(funcname); auto encodingInfo = Message().asBuilder(); encodingInfo.initShape(); @@ -46,9 +47,12 @@ Result setupTestCircuit(std::string source, inputs.setWithCaveats(1, encodingInfo); outputs.setWithCaveats(0, encodingInfo); - options.encodings->asBuilder().setName("main"); + options.encodings = Message(); + options.encodings->asBuilder().initCircuits(1).setWithCaveats( + 0, circuitEncoding.asReader()); + options.v0Parameter = {2, 10, 693, 4, 9, 7, 2, std::nullopt}; - TestCircuit testCircuit(options); + TestProgram testCircuit(options); OUTCOME_TRYV(testCircuit.compile({source})); OUTCOME_TRYV(testCircuit.generateKeyset()); return std::move(testCircuit); @@ -67,7 +71,7 @@ func.func @main( } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); uint64_t a = 5; uint64_t b = 5; auto res = circuit.call({Tensor(a), Tensor(b)}); diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp index 1e8001cc8..ec3487418 100644 --- a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp @@ -10,7 +10,7 @@ #include "concretelang/Common/Error.h" #include "concretelang/Support/CompilerEngine.h" -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "tests_tools/GtestEnvironment.h" #include "tests_tools/assert.h" @@ -20,15 +20,15 @@ using namespace concretelang::testlib; testing::Environment *const dfr_env = testing::AddGlobalTestEnvironment(new DFREnvironment); -Result setupTestCircuit(std::string source, +Result setupTestProgram(std::string source, std::string funcname = FUNCNAME) { - mlir::concretelang::CompilationOptions options(funcname); + mlir::concretelang::CompilationOptions options; #ifdef CONCRETELANG_CUDA_SUPPORT options.emitGPUOps = true; options.emitSDFGOps = true; #endif options.batchTFHEOps = true; - TestCircuit testCircuit(options); + TestProgram testCircuit(options); OUTCOME_TRYV(testCircuit.compile({source})); OUTCOME_TRYV(testCircuit.generateKeyset()); return std::move(testCircuit); @@ -41,7 +41,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { return %1: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) for (auto b : values_7bits()) { if (a > b) { @@ -61,7 +61,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { return %1: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) for (auto b : values_7bits()) { if (a > b) { @@ -81,7 +81,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { return %1: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_3bits()) for (auto b : values_3bits()) { if (a > b) { @@ -101,7 +101,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { return %1: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) { auto res = circuit.call({Tensor(a)}); ASSERT_TRUE(res.has_value()); @@ -119,7 +119,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, % return %3: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_3bits()) { for (auto b : values_3bits()) { auto res = circuit.call({ @@ -143,7 +143,7 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { return %1: !FHE.eint<3> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_3bits()) { auto res = circuit.call({Tensor(a)}); ASSERT_TRUE(res.has_value()); @@ -165,7 +165,7 @@ func.func @main(%arg0: !FHE.eint<4>) -> !FHE.eint<4> { return %6: !FHE.eint<4> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_3bits()) { auto res = circuit.call({Tensor(a)}); ASSERT_TRUE(res.has_value()); @@ -182,7 +182,7 @@ TEST(SDFG_unit_tests, tlu_batched) { return %res : tensor<3x3x!FHE.eint<3>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto t = Tensor({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3}); auto expected = Tensor({1, 3, 5, 7, 1, 3, 5, 7, 1}, {3, 3}); auto res = circuit.call({t}); @@ -201,7 +201,7 @@ TEST(SDFG_unit_tests, batched_tree) { return %res : tensor<3x3x!FHE.eint<4>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto t = Tensor({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3}); auto a1 = Tensor({0, 1, 0, 0, 1, 0, 0, 1, 0}, {3, 3}); auto a2 = Tensor({1, 0, 1, 1, 0, 1, 1, 0, 1}, {3, 3}); @@ -226,7 +226,7 @@ TEST(SDFG_unit_tests, batched_tree_mapped_tlu) { return %res : tensor<3x3x!FHE.eint<4>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto t = Tensor({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3}); auto a1 = Tensor({0, 1, 0, 0, 1, 0, 0, 1, 0}, {3, 3}); auto a2 = Tensor({1, 0, 1, 1, 0, 1, 1, 0, 1}, {3, 3}); diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/TestLib/testlib_unit_test.cpp b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/TestLib/testlib_unit_test.cpp index e23272777..76ce56fcb 100644 --- a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/TestLib/testlib_unit_test.cpp +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/TestLib/testlib_unit_test.cpp @@ -8,7 +8,7 @@ #include "concretelang/Common/Error.h" #include "concretelang/Support/CompilerEngine.h" -#include "concretelang/TestLib/TestCircuit.h" +#include "concretelang/TestLib/TestProgram.h" #include "tests_tools/GtestEnvironment.h" #include "tests_tools/assert.h" @@ -18,17 +18,17 @@ using namespace concretelang::testlib; testing::Environment *const dfr_env = testing::AddGlobalTestEnvironment(new DFREnvironment); -Result setupTestCircuit(std::string source, +Result setupTestProgram(std::string source, std::string funcname = FUNCNAME) { std::vector sources = {source}; std::shared_ptr ccx = mlir::concretelang::CompilationContext::createShared(); mlir::concretelang::CompilerEngine ce{ccx}; - mlir::concretelang::CompilationOptions options(funcname); + mlir::concretelang::CompilationOptions options; #ifdef CONCRETELANG_DATAFLOW_TESTING_ENABLED options.dataflowParallelize = true; #endif - TestCircuit testCircuit(options); + TestProgram testCircuit(options); OUTCOME_TRYV(testCircuit.compile({source})); OUTCOME_TRYV(testCircuit.generateKeyset()); return std::move(testCircuit); @@ -67,7 +67,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { return %arg0: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) { auto res = circuit.call({Tensor(a)}); ASSERT_TRUE(res.has_value()); @@ -82,7 +82,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { return %arg0: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) for (auto b : values_7bits()) { if (a > b) { @@ -102,7 +102,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { return %1: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) for (auto b : values_7bits()) { if (a > b) { @@ -122,7 +122,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { return %1: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto res = circuit.call({Tensor(1)}); ASSERT_FALSE(res.has_value()); } @@ -134,7 +134,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> { return %1: tensor<1x!FHE.eint<7>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) { auto res = circuit.call({Tensor(a)}); EXPECT_TRUE(res); @@ -150,7 +150,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint< return %1: tensor<2x!FHE.eint<7>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_7bits()) { auto res = circuit.call({Tensor(a), Tensor(a + 1)}); EXPECT_TRUE(res); @@ -168,7 +168,7 @@ func.func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> { return %1: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (uint8_t a : values_7bits()) { auto ta = Tensor({a}, {1}); auto res = circuit.call({ta}); @@ -184,7 +184,7 @@ func.func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { return %arg0: tensor<3x!FHE.eint<7>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto ta = Tensor({1, 2, 3}, {3}); auto res = circuit.call({ta}); ASSERT_TRUE(res); @@ -202,7 +202,7 @@ func.func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> return %3: !FHE.eint<7> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto ta = Tensor({1, 2, 3}, {3}); auto tb = Tensor({5, 7, 9}, {3}); auto res = circuit.call({ta, tb}); @@ -219,7 +219,7 @@ func.func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> { return %arg0: tensor<2x3x!FHE.eint<7>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto ta = Tensor({1, 2, 3, 4, 5, 6}, {2, 3}); auto res = circuit.call({ta}); ASSERT_TRUE(res); @@ -233,7 +233,7 @@ func.func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> return %arg0: tensor<2x3x1x!FHE.eint<7>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto ta = Tensor({1, 2, 3, 4, 5, 6}, {2, 3, 1}); auto res = circuit.call({ta}); ASSERT_TRUE(res); @@ -248,7 +248,7 @@ func.func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>, %arg1: tensor<2x3x1x!FHE.eint return %1: tensor<2x3x1x!FHE.eint<7>> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); auto ta = Tensor({1, 2, 3, 4, 5, 6}, {2, 3, 1}); auto res = circuit.call({ta, ta}); ASSERT_TRUE(res); @@ -312,7 +312,7 @@ func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<3>) -> !FHE.eint<6> { return %a_plus_b: !FHE.eint<6> } )"; - ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source)); + ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source)); for (auto a : values_6bits()) for (auto b : values_3bits()) { auto res = circuit.call({Tensor(a), Tensor(b)}); diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index b02ca15e9..1e61cc33e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -599,6 +599,10 @@ impl OperationDag { self.0.dump() } + fn concat(&mut self, other: &Self) { + self.0.concat(&other.0); + } + fn tag_operator_as_output(&mut self, op: ffi::OperatorIndex) { self.0.tag_operator_as_output(op.into()); } @@ -741,6 +745,8 @@ mod ffi { fn dump(self: &OperationDag) -> String; + fn concat(self: &mut OperationDag, other: &OperationDag); + #[namespace = "concrete_optimizer::dag"] fn dump(self: &CircuitSolution) -> String; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 5733ae123..77a08cc30 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -976,6 +976,7 @@ struct OperationDag final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; ::rust::String dump() const noexcept; + void concat(::concrete_optimizer::OperationDag const &other) noexcept; void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; ~OperationDag() = delete; @@ -1289,6 +1290,8 @@ extern "C" { void concrete_optimizer$cxxbridge1$OperationDag$optimize(::concrete_optimizer::OperationDag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept; void concrete_optimizer$cxxbridge1$OperationDag$dump(::concrete_optimizer::OperationDag const &self, ::rust::String *return$) noexcept; + +void concrete_optimizer$cxxbridge1$OperationDag$concat(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::OperationDag const &other) noexcept; } // extern "C" namespace dag { @@ -1390,6 +1393,10 @@ namespace dag { return ::std::move(return$.value); } +void OperationDag::concat(::concrete_optimizer::OperationDag const &other) noexcept { + concrete_optimizer$cxxbridge1$OperationDag$concat(*this, other); +} + namespace dag { ::rust::String CircuitSolution::dump() const noexcept { ::rust::MaybeUninit<::rust::String> return$; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 486d9adce..4636598f5 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -957,6 +957,7 @@ struct OperationDag final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; ::rust::String dump() const noexcept; + void concat(::concrete_optimizer::OperationDag const &other) noexcept; void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; ~OperationDag() = delete; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index a7a289401..ce144266f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -252,6 +252,29 @@ impl OperationDag { self.add_lut(rounded, table, out_precision) } + /// Concatenates two dags into a single one (with two disconnected clusters). + pub fn concat(&mut self, other: &Self) { + let length = self.len(); + self.operators.extend(other.operators.iter().cloned()); + self.out_precisions.extend(other.out_precisions.iter()); + self.out_shapes.extend(other.out_shapes.iter().cloned()); + self.output_tags.extend(other.output_tags.iter()); + self.operators[length..] + .iter_mut() + .for_each(|node| match node { + Operator::Lut { ref mut input, .. } + | Operator::UnsafeCast { ref mut input, .. } + | Operator::Round { ref mut input, .. } => { + input.i += length; + } + Operator::Dot { ref mut inputs, .. } + | Operator::LevelledOp { ref mut inputs, .. } => { + inputs.iter_mut().for_each(|inp| inp.i += length); + } + _ => (), + }); + } + /// Returns an iterator over input nodes indices. pub(crate) fn get_input_index_iter(&self) -> impl Iterator + '_ { self.operators @@ -315,6 +338,7 @@ impl OperationDag { DotKind::Broadcast { shape } => shape, DotKind::Unsupported { .. } => { let weights_shape = &weights.shape; + println!(); println!(); println!("Error diagnostic on dot operation:"); @@ -347,6 +371,33 @@ mod tests { use super::*; use crate::dag::operator::Shape; + #[test] + fn graph_concat() { + let mut graph1 = OperationDag::new(); + let a = graph1.add_input(1, Shape::number()); + let b = graph1.add_input(1, Shape::number()); + let c = graph1.add_dot([a, b], [1, 1]); + let _d = graph1.add_lut(c, FunctionTable::UNKWOWN, 1); + let mut graph2 = OperationDag::new(); + let a = graph2.add_input(2, Shape::number()); + let b = graph2.add_input(2, Shape::number()); + let c = graph2.add_dot([a, b], [2, 2]); + let _d = graph2.add_lut(c, FunctionTable::UNKWOWN, 2); + graph1.concat(&graph2); + + let mut graph3 = OperationDag::new(); + let a = graph3.add_input(1, Shape::number()); + let b = graph3.add_input(1, Shape::number()); + let c = graph3.add_dot([a, b], [1, 1]); + let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 1); + let a = graph3.add_input(2, Shape::number()); + let b = graph3.add_input(2, Shape::number()); + let c = graph3.add_dot([a, b], [2, 2]); + let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 2); + + assert_eq!(graph1, graph3); + } + #[test] fn graph_creation() { let mut graph = OperationDag::new(); diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 34755781e..f6da3b1e7 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -14,13 +14,13 @@ from typing import Dict, List, Optional, Tuple, Union import concrete.compiler from concrete.compiler import ( CompilationContext, - CompilationFeedback, CompilationOptions, EvaluationKeys, LibraryCompilationResult, LibraryLambda, LibrarySupport, Parameter, + ProgramCompilationFeedback, PublicArguments, set_compiler_logging, set_llvm_debug_flag, @@ -59,7 +59,7 @@ class Server: _output_dir: Optional[tempfile.TemporaryDirectory] _support: LibrarySupport _compilation_result: LibraryCompilationResult - _compilation_feedback: CompilationFeedback + _compilation_feedback: ProgramCompilationFeedback _server_lambda: LibraryLambda _mlir: Optional[str] @@ -114,7 +114,7 @@ class Server: """ backend = Backend.GPU if configuration.use_gpu else Backend.CPU - options = CompilationOptions.new("main", backend) + options = CompilationOptions.new(backend) options.simulation(is_simulated) @@ -312,7 +312,7 @@ class Server: generateCppHeader=False, generateStaticLib=False, ) - compilation_result = support.reload("main") + compilation_result = support.reload() server_lambda = support.load_server_lambda(compilation_result, is_simulated) return Server( @@ -408,14 +408,14 @@ class Server: """ Get size of the inputs of the compiled program. """ - return self._compilation_feedback.total_inputs_size + return self._compilation_feedback.circuit("main").total_inputs_size @property def size_of_outputs(self) -> int: """ Get size of the outputs of the compiled program. """ - return self._compilation_feedback.total_output_size + return self._compilation_feedback.circuit("main").total_output_size @property def p_error(self) -> int: @@ -445,7 +445,7 @@ class Server: """ Get the number of programmable bootstraps in the compiled program. """ - return self._compilation_feedback.count( + return self._compilation_feedback.circuit("main").count( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) @@ -454,7 +454,7 @@ class Server: """ Get the number of programmable bootstraps per parameter in the compiled program. """ - return self._compilation_feedback.count_per_parameter( + return self._compilation_feedback.circuit("main").count_per_parameter( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, client_parameters=self.client_specs.client_parameters, @@ -465,7 +465,7 @@ class Server: """ Get the number of programmable bootstraps per tag in the compiled program. """ - return self._compilation_feedback.count_per_tag( + return self._compilation_feedback.circuit("main").count_per_tag( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) @@ -474,7 +474,7 @@ class Server: """ Get the number of programmable bootstraps per tag per parameter in the compiled program. """ - return self._compilation_feedback.count_per_tag_per_parameter( + return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, client_parameters=self.client_specs.client_parameters, @@ -487,7 +487,7 @@ class Server: """ Get the number of key switches in the compiled program. """ - return self._compilation_feedback.count( + return self._compilation_feedback.circuit("main").count( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) @@ -496,7 +496,7 @@ class Server: """ Get the number of key switches per parameter in the compiled program. """ - return self._compilation_feedback.count_per_parameter( + return self._compilation_feedback.circuit("main").count_per_parameter( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -507,7 +507,7 @@ class Server: """ Get the number of key switches per tag in the compiled program. """ - return self._compilation_feedback.count_per_tag( + return self._compilation_feedback.circuit("main").count_per_tag( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) @@ -516,7 +516,7 @@ class Server: """ Get the number of key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.count_per_tag_per_parameter( + return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -529,14 +529,16 @@ class Server: """ Get the number of packing key switches in the compiled program. """ - return self._compilation_feedback.count(operations={PrimitiveOperation.WOP_PBS}) + return self._compilation_feedback.circuit("main").count( + operations={PrimitiveOperation.WOP_PBS} + ) @property def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of packing key switches per parameter in the compiled program. """ - return self._compilation_feedback.count_per_parameter( + return self._compilation_feedback.circuit("main").count_per_parameter( operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -547,14 +549,16 @@ class Server: """ Get the number of packing key switches per tag in the compiled program. """ - return self._compilation_feedback.count_per_tag(operations={PrimitiveOperation.WOP_PBS}) + return self._compilation_feedback.circuit("main").count_per_tag( + operations={PrimitiveOperation.WOP_PBS} + ) @property def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ Get the number of packing key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.count_per_tag_per_parameter( + return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -567,14 +571,16 @@ class Server: """ Get the number of clear additions in the compiled program. """ - return self._compilation_feedback.count(operations={PrimitiveOperation.CLEAR_ADDITION}) + return self._compilation_feedback.circuit("main").count( + operations={PrimitiveOperation.CLEAR_ADDITION} + ) @property def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of clear additions per parameter in the compiled program. """ - return self._compilation_feedback.count_per_parameter( + return self._compilation_feedback.circuit("main").count_per_parameter( operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -585,7 +591,7 @@ class Server: """ Get the number of clear additions per tag in the compiled program. """ - return self._compilation_feedback.count_per_tag( + return self._compilation_feedback.circuit("main").count_per_tag( operations={PrimitiveOperation.CLEAR_ADDITION}, ) @@ -594,7 +600,7 @@ class Server: """ Get the number of clear additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.count_per_tag_per_parameter( + return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -607,14 +613,16 @@ class Server: """ Get the number of encrypted additions in the compiled program. """ - return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_ADDITION}) + return self._compilation_feedback.circuit("main").count( + operations={PrimitiveOperation.ENCRYPTED_ADDITION} + ) @property def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of encrypted additions per parameter in the compiled program. """ - return self._compilation_feedback.count_per_parameter( + return self._compilation_feedback.circuit("main").count_per_parameter( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -625,7 +633,7 @@ class Server: """ Get the number of encrypted additions per tag in the compiled program. """ - return self._compilation_feedback.count_per_tag( + return self._compilation_feedback.circuit("main").count_per_tag( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, ) @@ -634,7 +642,7 @@ class Server: """ Get the number of encrypted additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.count_per_tag_per_parameter( + return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -647,7 +655,7 @@ class Server: """ Get the number of clear multiplications in the compiled program. """ - return self._compilation_feedback.count( + return self._compilation_feedback.circuit("main").count( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) @@ -656,7 +664,7 @@ class Server: """ Get the number of clear multiplications per parameter in the compiled program. """ - return self._compilation_feedback.count_per_parameter( + return self._compilation_feedback.circuit("main").count_per_parameter( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -667,7 +675,7 @@ class Server: """ Get the number of clear multiplications per tag in the compiled program. """ - return self._compilation_feedback.count_per_tag( + return self._compilation_feedback.circuit("main").count_per_tag( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) @@ -676,7 +684,7 @@ class Server: """ Get the number of clear multiplications per tag per parameter in the compiled program. """ - return self._compilation_feedback.count_per_tag_per_parameter( + return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -689,14 +697,16 @@ class Server: """ Get the number of encrypted negations in the compiled program. """ - return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_NEGATION}) + return self._compilation_feedback.circuit("main").count( + operations={PrimitiveOperation.ENCRYPTED_NEGATION} + ) @property def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of encrypted negations per parameter in the compiled program. """ - return self._compilation_feedback.count_per_parameter( + return self._compilation_feedback.circuit("main").count_per_parameter( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -707,7 +717,7 @@ class Server: """ Get the number of encrypted negations per tag in the compiled program. """ - return self._compilation_feedback.count_per_tag( + return self._compilation_feedback.circuit("main").count_per_tag( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, ) @@ -716,7 +726,7 @@ class Server: """ Get the number of encrypted negations per tag per parameter in the compiled program. """ - return self._compilation_feedback.count_per_tag_per_parameter( + return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, diff --git a/tools/concrete-protocol/src/concrete-protocol.capnp b/tools/concrete-protocol/src/concrete-protocol.capnp index 41a485411..f13ae3bc6 100644 --- a/tools/concrete-protocol/src/concrete-protocol.capnp +++ b/tools/concrete-protocol/src/concrete-protocol.capnp @@ -399,6 +399,13 @@ struct CircuitEncodingInfo { name @2 :Text; # The name of the circuit. } +struct ProgramEncodingInfo { + # A program encodings is described by the set of circuit encodings. This structure represents + # this ensemble of encoding signatures. + + circuits @0 :List(CircuitEncodingInfo); # The list of the circuit encoding infos. +} + ###################################################################################### Encryption ## struct LweCiphertextEncryptionInfo {