feat: give crt decomposition feedback

This commit is contained in:
Umut
2022-11-21 10:15:12 +01:00
committed by Quentin Bourgerie
parent 1e8c0df381
commit 722e4d2eba
5 changed files with 41 additions and 2 deletions

View File

@@ -37,6 +37,9 @@ struct CompilationFeedback {
/// @brief the total number of bytes of outputs
uint64_t totalOutputsSize;
/// @brief crt decomposition of outputs, if crt is not used, empty vectors
std::vector<std::vector<int64_t>> crtDecompositionsOfOutputs;
/// Fill the sizes from the client parameters.
void
fillFromClientParameters(::concretelang::clientlib::ClientParameters params);

View File

@@ -88,7 +88,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def_readonly("total_inputs_size",
&mlir::concretelang::CompilationFeedback::totalInputsSize)
.def_readonly("total_output_size",
&mlir::concretelang::CompilationFeedback::totalOutputsSize);
&mlir::concretelang::CompilationFeedback::totalOutputsSize)
.def_readonly(
"crt_decompositions_of_outputs",
&mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs);
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JITCompilationResult");

View File

@@ -36,5 +36,8 @@ class CompilationFeedback(WrapperCpp):
self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size
self.total_inputs_size = compilation_feedback.total_inputs_size
self.total_output_size = compilation_feedback.total_output_size
self.crt_decompositions_of_outputs = (
compilation_feedback.crt_decompositions_of_outputs
)
super().__init__(compilation_feedback)

View File

@@ -53,6 +53,15 @@ void CompilationFeedback::fillFromClientParameters(
for (auto gate : params.outputs) {
totalOutputsSize += gate.byteSize(params.secretKeys);
}
// Extract CRT decomposition
crtDecompositionsOfOutputs = {};
for (auto gate : params.outputs) {
std::vector<int64_t> decomposition;
if (gate.encryption.hasValue()) {
decomposition = gate.encryption->encoding.crt;
}
crtDecompositionsOfOutputs.push_back(decomposition);
}
}
outcome::checked<CompilationFeedback, StringError>
@@ -80,6 +89,7 @@ llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
{"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize},
{"totalInputsSize", v.totalInputsSize},
{"totalOutputsSize", v.totalOutputsSize},
{"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs},
};
return object;
}
@@ -92,7 +102,8 @@ bool fromJSON(const llvm::json::Value j,
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
O.map("totalInputsSize", v.totalInputsSize) &&
O.map("totalOutputsSize", v.totalOutputsSize);
O.map("totalOutputsSize", v.totalOutputsSize) &&
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs);
}
} // namespace concretelang

View File

@@ -374,3 +374,22 @@ def test_compile_invalid(mlir_input):
RuntimeError, match=r"Could not find existing crypto parameters for"
):
engine.compile(mlir_input)
def test_crt_decomposition_feedback():
mlir = """
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%tlu = arith.constant dense<60000> : tensor<65536xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<16>, tensor<65536xi64>) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
"""
engine = JITSupport.new()
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]]