From 722e4d2ebad87ccc31d0c063c9b523f3eb53c480 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 21 Nov 2022 10:15:12 +0100 Subject: [PATCH] feat: give crt decomposition feedback --- .../Support/CompilationFeedback.h | 3 +++ .../lib/Bindings/Python/CompilerAPIModule.cpp | 5 ++++- .../concrete/compiler/compilation_feedback.py | 3 +++ compiler/lib/Support/CompilationFeedback.cpp | 13 ++++++++++++- compiler/tests/python/test_compilation.py | 19 +++++++++++++++++++ 5 files changed, 41 insertions(+), 2 deletions(-) diff --git a/compiler/include/concretelang/Support/CompilationFeedback.h b/compiler/include/concretelang/Support/CompilationFeedback.h index 8a0d9539b..bf93d0ff1 100644 --- a/compiler/include/concretelang/Support/CompilationFeedback.h +++ b/compiler/include/concretelang/Support/CompilationFeedback.h @@ -37,6 +37,9 @@ struct CompilationFeedback { /// @brief the total number of bytes of outputs uint64_t totalOutputsSize; + /// @brief crt decomposition of outputs, if crt is not used, empty vectors + std::vector> crtDecompositionsOfOutputs; + /// Fill the sizes from the client parameters. void fillFromClientParameters(::concretelang::clientlib::ClientParameters params); diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 0aa395f03..be9c1203c 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -88,7 +88,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def_readonly("total_inputs_size", &mlir::concretelang::CompilationFeedback::totalInputsSize) .def_readonly("total_output_size", - &mlir::concretelang::CompilationFeedback::totalOutputsSize); + &mlir::concretelang::CompilationFeedback::totalOutputsSize) + .def_readonly( + "crt_decompositions_of_outputs", + &mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs); pybind11::class_( m, "JITCompilationResult"); diff --git a/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index 5096d6e3c..ec5e77d54 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -36,5 +36,8 @@ class CompilationFeedback(WrapperCpp): self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size self.total_inputs_size = compilation_feedback.total_inputs_size self.total_output_size = compilation_feedback.total_output_size + self.crt_decompositions_of_outputs = ( + compilation_feedback.crt_decompositions_of_outputs + ) super().__init__(compilation_feedback) diff --git a/compiler/lib/Support/CompilationFeedback.cpp b/compiler/lib/Support/CompilationFeedback.cpp index cec0a12e0..9ca87a8cf 100644 --- a/compiler/lib/Support/CompilationFeedback.cpp +++ b/compiler/lib/Support/CompilationFeedback.cpp @@ -53,6 +53,15 @@ void CompilationFeedback::fillFromClientParameters( for (auto gate : params.outputs) { totalOutputsSize += gate.byteSize(params.secretKeys); } + // Extract CRT decomposition + crtDecompositionsOfOutputs = {}; + for (auto gate : params.outputs) { + std::vector decomposition; + if (gate.encryption.hasValue()) { + decomposition = gate.encryption->encoding.crt; + } + crtDecompositionsOfOutputs.push_back(decomposition); + } } outcome::checked @@ -80,6 +89,7 @@ llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) { {"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize}, {"totalInputsSize", v.totalInputsSize}, {"totalOutputsSize", v.totalOutputsSize}, + {"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs}, }; return object; } @@ -92,7 +102,8 @@ bool fromJSON(const llvm::json::Value j, O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) && O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) && O.map("totalInputsSize", v.totalInputsSize) && - O.map("totalOutputsSize", v.totalOutputsSize); + O.map("totalOutputsSize", v.totalOutputsSize) && + O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs); } } // namespace concretelang diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 505de40e9..49905fa77 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -374,3 +374,22 @@ def test_compile_invalid(mlir_input): RuntimeError, match=r"Could not find existing crypto parameters for" ): engine.compile(mlir_input) + + +def test_crt_decomposition_feedback(): + mlir = """ + +func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + %tlu = arith.constant dense<60000> : tensor<65536xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<16>, tensor<65536xi64>) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> +} + + """ + + engine = JITSupport.new() + compilation_result = engine.compile(mlir, options=CompilationOptions.new("main")) + compilation_feedback = engine.load_compilation_feedback(compilation_result) + + assert isinstance(compilation_feedback, CompilationFeedback) + assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]]