mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: give crt decomposition feedback
This commit is contained in:
@@ -37,6 +37,9 @@ struct CompilationFeedback {
|
||||
/// @brief the total number of bytes of outputs
|
||||
uint64_t totalOutputsSize;
|
||||
|
||||
/// @brief crt decomposition of outputs, if crt is not used, empty vectors
|
||||
std::vector<std::vector<int64_t>> crtDecompositionsOfOutputs;
|
||||
|
||||
/// Fill the sizes from the client parameters.
|
||||
void
|
||||
fillFromClientParameters(::concretelang::clientlib::ClientParameters params);
|
||||
|
||||
@@ -88,7 +88,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def_readonly("total_inputs_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalInputsSize)
|
||||
.def_readonly("total_output_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalOutputsSize);
|
||||
&mlir::concretelang::CompilationFeedback::totalOutputsSize)
|
||||
.def_readonly(
|
||||
"crt_decompositions_of_outputs",
|
||||
&mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs);
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JITCompilationResult");
|
||||
|
||||
@@ -36,5 +36,8 @@ class CompilationFeedback(WrapperCpp):
|
||||
self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size
|
||||
self.total_inputs_size = compilation_feedback.total_inputs_size
|
||||
self.total_output_size = compilation_feedback.total_output_size
|
||||
self.crt_decompositions_of_outputs = (
|
||||
compilation_feedback.crt_decompositions_of_outputs
|
||||
)
|
||||
|
||||
super().__init__(compilation_feedback)
|
||||
|
||||
@@ -53,6 +53,15 @@ void CompilationFeedback::fillFromClientParameters(
|
||||
for (auto gate : params.outputs) {
|
||||
totalOutputsSize += gate.byteSize(params.secretKeys);
|
||||
}
|
||||
// Extract CRT decomposition
|
||||
crtDecompositionsOfOutputs = {};
|
||||
for (auto gate : params.outputs) {
|
||||
std::vector<int64_t> decomposition;
|
||||
if (gate.encryption.hasValue()) {
|
||||
decomposition = gate.encryption->encoding.crt;
|
||||
}
|
||||
crtDecompositionsOfOutputs.push_back(decomposition);
|
||||
}
|
||||
}
|
||||
|
||||
outcome::checked<CompilationFeedback, StringError>
|
||||
@@ -80,6 +89,7 @@ llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
|
||||
{"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize},
|
||||
{"totalInputsSize", v.totalInputsSize},
|
||||
{"totalOutputsSize", v.totalOutputsSize},
|
||||
{"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
@@ -92,7 +102,8 @@ bool fromJSON(const llvm::json::Value j,
|
||||
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
|
||||
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
|
||||
O.map("totalInputsSize", v.totalInputsSize) &&
|
||||
O.map("totalOutputsSize", v.totalOutputsSize);
|
||||
O.map("totalOutputsSize", v.totalOutputsSize) &&
|
||||
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -374,3 +374,22 @@ def test_compile_invalid(mlir_input):
|
||||
RuntimeError, match=r"Could not find existing crypto parameters for"
|
||||
):
|
||||
engine.compile(mlir_input)
|
||||
|
||||
|
||||
def test_crt_decomposition_feedback():
|
||||
mlir = """
|
||||
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%tlu = arith.constant dense<60000> : tensor<65536xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<16>, tensor<65536xi64>) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
engine = JITSupport.new()
|
||||
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]]
|
||||
|
||||
Reference in New Issue
Block a user