diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index 480d86772..00a997055 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -6,7 +6,8 @@ use concrete_optimizer::dag::operator::{ }; use concrete_optimizer::dag::unparametrized; use concrete_optimizer::optimization::config::{Config, SearchSpace}; -use concrete_optimizer::optimization::dag::multi_parameters::keys_spec; +use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::{self, CircuitSolution}; +use concrete_optimizer::optimization::dag::multi_parameters::optimize::optimize_to_circuit_solution; use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{ Encoding, Solution as DagSolution, }; @@ -164,60 +165,81 @@ impl From for ffi::DagSolution { } } -impl ffi::CircuitSolution { - fn of(sol: ffi::DagSolution, dag: &OperationDag) -> Self { - let big_key = ffi::SecretLweKey { - identifier: 0, - polynomial_size: sol.glwe_polynomial_size, - glwe_dimension: sol.glwe_dimension, - description: "big representation".into(), - }; - let small_key = ffi::SecretLweKey { - identifier: 1, - polynomial_size: sol.internal_ks_output_lwe_dimension, - glwe_dimension: 1, - description: "small representation".into(), - }; - let keyswitch_key = ffi::KeySwitchKey { - identifier: 0, - input_key: big_key.clone(), - output_key: small_key.clone(), - ks_decomposition_parameter: ffi::KsDecompositionParameters { - level: sol.ks_decomposition_level_count, - log2_base: sol.ks_decomposition_base_log, - }, - description: "tlu keyswitch".into(), - }; - let bootstrap_key = ffi::BootstrapKey { - identifier: 0, - input_key: small_key.clone(), - output_key: big_key.clone(), - br_decomposition_parameter: ffi::BrDecompositionParameters { - level: sol.br_decomposition_level_count, - log2_base: sol.br_decomposition_base_log, - }, - description: "tlu bootsrap".into(), - }; - let instruction_keys = ffi::InstructionKeys { - input_key: big_key.identifier, - tlu_keyswitch_key: keyswitch_key.identifier, - tlu_bootstrap_key: bootstrap_key.identifier, - output_key: big_key.identifier, - extra_conversion_keys: vec![], - }; - let instructions_keys = vec![instruction_keys; dag.0.len()]; - let circuit_keys = ffi::CircuitKeys { - secret_keys: [big_key, small_key].into(), - keyswitch_keys: [keyswitch_key].into(), - bootstrap_keys: [bootstrap_key].into(), - conversion_keyswitch_keys: [].into(), - }; - ffi::CircuitSolution { - circuit_keys, - instructions_keys, - complexity: sol.complexity, - p_error: sol.p_error, - global_p_error: sol.global_p_error, +fn convert_to_circuit_solution(sol: &ffi::DagSolution, dag: &OperationDag) -> ffi::CircuitSolution { + let big_key = ffi::SecretLweKey { + identifier: 0, + polynomial_size: sol.glwe_polynomial_size, + glwe_dimension: sol.glwe_dimension, + description: "big representation".into(), + }; + let small_key = ffi::SecretLweKey { + identifier: 1, + polynomial_size: sol.internal_ks_output_lwe_dimension, + glwe_dimension: 1, + description: "small representation".into(), + }; + let keyswitch_key = ffi::KeySwitchKey { + identifier: 0, + input_key: big_key.clone(), + output_key: small_key.clone(), + ks_decomposition_parameter: ffi::KsDecompositionParameters { + level: sol.ks_decomposition_level_count, + log2_base: sol.ks_decomposition_base_log, + }, + description: "tlu keyswitch".into(), + }; + let bootstrap_key = ffi::BootstrapKey { + identifier: 0, + input_key: small_key.clone(), + output_key: big_key.clone(), + br_decomposition_parameter: ffi::BrDecompositionParameters { + level: sol.br_decomposition_level_count, + log2_base: sol.br_decomposition_base_log, + }, + description: "tlu bootstrap".into(), + }; + let instruction_keys = ffi::InstructionKeys { + input_key: big_key.identifier, + tlu_keyswitch_key: keyswitch_key.identifier, + tlu_bootstrap_key: bootstrap_key.identifier, + output_key: big_key.identifier, + extra_conversion_keys: vec![], + }; + let instructions_keys = vec![instruction_keys; dag.0.len()]; + let circuit_keys = ffi::CircuitKeys { + secret_keys: [big_key, small_key].into(), + keyswitch_keys: [keyswitch_key].into(), + bootstrap_keys: [bootstrap_key].into(), + conversion_keyswitch_keys: [].into(), + }; + let is_feasible = sol.p_error < 1.0; + let error_msg = if is_feasible { + "" + } else { + "No crypto-parameters for the given constraints" + } + .into(); + ffi::CircuitSolution { + circuit_keys, + instructions_keys, + complexity: sol.complexity, + p_error: sol.p_error, + global_p_error: sol.global_p_error, + is_feasible, + error_msg, + } +} + +impl From for ffi::CircuitSolution { + fn from(v: CircuitSolution) -> Self { + Self { + circuit_keys: v.circuit_keys.into(), + instructions_keys: vec_into(v.instructions_keys), + complexity: v.complexity, + p_error: v.p_error, + global_p_error: v.global_p_error, + is_feasible: v.is_feasible, + error_msg: v.error_msg, } } } @@ -230,7 +252,7 @@ impl ffi::CircuitSolution { impl From for ffi::KsDecompositionParameters { fn from(v: KsDecompositionParameters) -> Self { - ffi::KsDecompositionParameters { + Self { level: v.level, log2_base: v.log2_base, } @@ -239,7 +261,7 @@ impl From for ffi::KsDecompositionParameters { impl From for ffi::BrDecompositionParameters { fn from(v: BrDecompositionParameters) -> Self { - ffi::BrDecompositionParameters { + Self { level: v.level, log2_base: v.log2_base, } @@ -294,6 +316,18 @@ impl From for ffi::BootstrapKey { } } +impl From for ffi::InstructionKeys { + fn from(v: keys_spec::InstructionKeys) -> Self { + Self { + input_key: v.input_key, + tlu_keyswitch_key: v.tlu_keyswitch_key, + tlu_bootstrap_key: v.tlu_bootstrap_key, + output_key: v.output_key, + extra_conversion_keys: v.extra_conversion_keys, + } + } +} + fn vec_into>(vec: Vec) -> Vec { vec.into_iter().map(|x| x.into()).collect() } @@ -432,8 +466,22 @@ impl OperationDag { } fn optimize_multi(&self, options: ffi::Options) -> ffi::CircuitSolution { - let single_parameter = self.optimize(options); - ffi::CircuitSolution::of(single_parameter, self) + let processing_unit = processing_unit(options); + let config = Config { + security_level: options.security_level, + maximum_acceptable_error_probability: options.maximum_acceptable_error_probability, + ciphertext_modulus_log: 64, + complexity_model: &CpuComplexity::default(), + }; + let search_space = SearchSpace::default(processing_unit); + let circuit_sol = optimize_to_circuit_solution( + &self.0, + config, + &search_space, + &caches_from(options), + &None, + ); + circuit_sol.into() } } @@ -480,6 +528,12 @@ mod ffi { #[namespace = "concrete_optimizer::utils"] fn convert_to_dag_solution(solution: &Solution) -> DagSolution; + #[namespace = "concrete_optimizer::utils"] + fn convert_to_circuit_solution( + solution: &DagSolution, + dag: &OperationDag, + ) -> CircuitSolution; + type OperationDag; #[namespace = "concrete_optimizer::dag"] @@ -685,6 +739,8 @@ mod ffi { pub complexity: f64, pub p_error: f64, pub global_p_error: f64, + pub is_feasible: bool, + pub error_msg: String, } } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 58b24b90f..24c69c22f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -1188,6 +1188,8 @@ struct CircuitSolution final { double complexity; double p_error; double global_p_error; + bool is_feasible; + ::rust::String error_msg; ::rust::String dump() const noexcept; using IsRelocatable = ::std::true_type; @@ -1204,6 +1206,8 @@ extern "C" { namespace utils { extern "C" { void concrete_optimizer$utils$cxxbridge1$convert_to_dag_solution(::concrete_optimizer::v0::Solution const &solution, ::concrete_optimizer::dag::DagSolution *return$) noexcept; + +void concrete_optimizer$utils$cxxbridge1$convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept; } // extern "C" } // namespace utils @@ -1269,6 +1273,12 @@ namespace utils { concrete_optimizer$utils$cxxbridge1$convert_to_dag_solution(solution, &return$.value); return ::std::move(return$.value); } + +::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag) noexcept { + ::rust::MaybeUninit<::concrete_optimizer::dag::CircuitSolution> return$; + concrete_optimizer$utils$cxxbridge1$convert_to_circuit_solution(solution, dag, &return$.value); + return ::std::move(return$.value); +} } // namespace utils ::std::size_t OperationDag::layout::size() noexcept { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 791d7750b..ddc053baa 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -1169,6 +1169,8 @@ struct CircuitSolution final { double complexity; double p_error; double global_p_error; + bool is_feasible; + ::rust::String error_msg; ::rust::String dump() const noexcept; using IsRelocatable = ::std::true_type; @@ -1182,6 +1184,8 @@ namespace v0 { namespace utils { ::concrete_optimizer::dag::DagSolution convert_to_dag_solution(::concrete_optimizer::v0::Solution const &solution) noexcept; + +::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag) noexcept; } // namespace utils namespace dag { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp index eab317262..bffc73fba 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp @@ -109,7 +109,7 @@ void test_dag_lut_force_wop() { assert(!solution.crt_decomposition.empty()); } -void test_multi_parameters() { +void test_multi_parameters_1_precision() { auto dag = concrete_optimizer::dag::empty(); std::vector shape = {3}; @@ -122,8 +122,53 @@ void test_multi_parameters() { auto options = default_options(); auto circuit_solution = dag->optimize_multi(options); - auto secret_keys = circuit_solution.circuit_keys.keyswitch_keys; - assert(!secret_keys.empty()); + assert(circuit_solution.is_feasible); + auto secret_keys = circuit_solution.circuit_keys.secret_keys; + assert(circuit_solution.circuit_keys.secret_keys.size() == 2); + assert(circuit_solution.circuit_keys.secret_keys[0].identifier == 0); + assert(circuit_solution.circuit_keys.secret_keys[1].identifier == 1); + assert(circuit_solution.circuit_keys.bootstrap_keys.size() == 1); + assert(circuit_solution.circuit_keys.keyswitch_keys.size() == 1); + assert(circuit_solution.circuit_keys.keyswitch_keys[0].identifier == 0); + assert(circuit_solution.circuit_keys.keyswitch_keys[0].identifier == 0); + assert(circuit_solution.circuit_keys.conversion_keyswitch_keys.size() == 0); +} + +void test_multi_parameters_2_precision() { + auto dag = concrete_optimizer::dag::empty(); + + std::vector shape = {3}; + + concrete_optimizer::dag::OperatorIndex input1 = + dag->add_input(PRECISION_8B, slice(shape)); + + concrete_optimizer::dag::OperatorIndex input2 = + dag->add_input(PRECISION_1B, slice(shape)); + + + std::vector table = {}; + auto lut1 = dag->add_lut(input1, slice(table), PRECISION_8B); + auto lut2 = dag->add_lut(input2, slice(table), PRECISION_8B); + + std::vector inputs = {lut1, lut2}; + + std::vector weight_vec = {1, 1}; + + rust::cxxbridge1::Box weights = + concrete_optimizer::weights::vector(slice(weight_vec)); + + dag->add_dot(slice(inputs), std::move(weights)); + + auto options = default_options(); + auto circuit_solution = dag->optimize_multi(options); + assert(circuit_solution.is_feasible); + auto secret_keys = circuit_solution.circuit_keys.secret_keys; + assert(circuit_solution.circuit_keys.secret_keys.size() == 4); + assert(circuit_solution.circuit_keys.bootstrap_keys.size() == 2); + assert(circuit_solution.circuit_keys.keyswitch_keys.size() == 2); // 1 layer so less ks + std::string actual = circuit_solution.circuit_keys.conversion_keyswitch_keys[0].description.c_str(); + std::string expected = "fks[1->0]"; + assert(actual == expected); } int main() { @@ -132,7 +177,8 @@ int main() { test_dag_lut(); test_dag_lut_wop(); test_dag_lut_force_wop(); - test_multi_parameters(); + test_multi_parameters_1_precision(); + test_multi_parameters_2_precision(); return 0; } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs index 771903d54..d4a295216 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs @@ -21,7 +21,7 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { pub(crate) fn regen( dag: &OperationDag, f: &mut dyn FnMut(usize, &Operator, &mut OperationDag) -> Option, -) -> OperationDag { +) -> (OperationDag, Vec>) { let mut regen_dag = OperationDag::new(); let mut old_index_to_new = vec![]; for (i, op) in dag.operators.iter().enumerate() { @@ -37,5 +37,25 @@ pub(crate) fn regen( regen_dag.out_shapes.push(dag.out_shapes[i].clone()); } } - regen_dag + (regen_dag, instructions_multi_map(&old_index_to_new)) +} + +fn instructions_multi_map(old_index_to_new: &[usize]) -> Vec> { + let mut last_new_instr = None; + let mut result = vec![]; + result.reserve_exact(old_index_to_new.len()); + for &new_instr in old_index_to_new { + let start_from = last_new_instr.map_or(new_instr, |v: usize| v + 1); + if start_from <= new_instr { + result.push( + (start_from..=new_instr) + .map(|i| OperatorIndex { i }) + .collect(), + ); + } else { + result.push(vec![]); + } + last_new_instr = Some(new_instr); + } + result } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs index ffbd73f38..04487ecef 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs @@ -14,5 +14,11 @@ fn regen_round(_: usize, op: &Operator, dag: &mut OperationDag) -> Option OperationDag { + regen(dag, &mut regen_round).0 +} + +pub(crate) fn expand_round_and_index_map( + dag: &OperationDag, +) -> (OperationDag, Vec>) { regen(dag, &mut regen_round) } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index f45933064..5567d3f1e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -3,7 +3,7 @@ use std::collections::HashSet; use crate::dag::operator::{ dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, }; -use crate::dag::rewrite::round::expand_round; +use crate::dag::rewrite::round::expand_round_and_index_map; use crate::dag::unparametrized; use crate::optimization::config::NoiseBoundConfig; use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred; @@ -17,6 +17,7 @@ use crate::optimization::dag::solo_key::analyze::{ }; use super::complexity::OperationsCount; +use super::keys_spec; use super::operations_value::OperationsValue; use super::variance_constraint::VarianceConstraint; @@ -41,6 +42,7 @@ pub struct AnalyzedDag { pub undominated_variance_constraints: Vec, pub operations_count_per_instrs: Vec, pub operations_count: OperationsCount, + pub instruction_rewrite_index: Vec>, pub p_cut: PrecisionCut, } @@ -50,7 +52,7 @@ pub fn analyze( p_cut: &Option, default_partition: PartitionIndex, ) -> AnalyzedDag { - let dag = expand_round(dag); + let (dag, instruction_rewrite_index) = expand_round_and_index_map(dag); let levelled_complexity = LevelledComplexity::ZERO; // The precision cut is chosen to work well with rounded pbs // Note: this is temporary @@ -72,6 +74,7 @@ pub fn analyze( let operations_count = sum_operations_count(&operations_count_per_instrs); AnalyzedDag { operators: dag.operators, + instruction_rewrite_index, nb_partitions, instrs_partition, out_variances, @@ -84,6 +87,77 @@ pub fn analyze( } } +pub fn original_instrs_partition( + dag: &AnalyzedDag, + keys: &keys_spec::ExpandedCircuitKeys, +) -> Vec { + let big_keys = &keys.big_secret_keys; + let ks_keys = &keys.keyswitch_keys; + let pbs_keys = &keys.bootstrap_keys; + let fks_keys = &keys.conversion_keyswitch_keys; + let mut result = vec![]; + result.reserve_exact(dag.instruction_rewrite_index.len()); + let unknown = keys_spec::Id::MAX; + for new_instructions in &dag.instruction_rewrite_index { + let mut partition = None; + let mut input_partition = None; + let mut tlu_keyswitch_key = None; + let mut tlu_bootstrap_key = None; + let mut conversion_key = None; + // let mut extra_conversion_keys = None; + for (i, new_instruction) in new_instructions.iter().enumerate() { + // focus on TLU information + let new_instr_part = &dag.instrs_partition[new_instruction.i]; + if let Op::Lut { .. } = dag.operators[new_instruction.i] { + let ks_dst = new_instr_part.instruction_partition; + partition = Some(ks_dst); + #[allow(clippy::match_on_vec_items)] + let ks_src = match new_instr_part.inputs_transition[0] { + Some(Transition::Internal { src_partition }) => src_partition, + None => ks_dst, + _ => unreachable!(), + }; + input_partition = Some(ks_src); + let ks_key = ks_keys[ks_src][ks_dst].as_ref().unwrap().identifier; + let pbs_key = pbs_keys[ks_dst].identifier; + assert!(tlu_keyswitch_key.unwrap_or(ks_key) == ks_key); + assert!(tlu_bootstrap_key.unwrap_or(pbs_key) == pbs_key); + tlu_keyswitch_key = Some(ks_key); + tlu_bootstrap_key = Some(pbs_key); + } + if !new_instr_part.alternative_output_representation.is_empty() { + assert!(new_instr_part.alternative_output_representation.len() == 1); + let src = new_instr_part.instruction_partition; + let dst = *new_instr_part + .alternative_output_representation + .iter() + .next() + .unwrap(); + let key = fks_keys[src][dst].as_ref().unwrap().identifier; + assert!(conversion_key.unwrap_or(key) == key); + conversion_key = Some(key); + } + // Only last instruction can have alternative conversion + assert!( + new_instr_part.alternative_output_representation.is_empty() + || i == new_instructions.len() - 1 + ); + } + let partition = + partition.unwrap_or(dag.instrs_partition[new_instructions[0].i].instruction_partition); + let input_partition = input_partition.unwrap_or(partition); + let merged = keys_spec::InstructionKeys { + input_key: big_keys[input_partition].identifier, + tlu_keyswitch_key: tlu_keyswitch_key.unwrap_or(unknown), + tlu_bootstrap_key: tlu_bootstrap_key.unwrap_or(unknown), + output_key: big_keys[partition].identifier, + extra_conversion_keys: conversion_key.iter().copied().collect(), + }; + result.push(merged); + } + result +} + fn out_variance( op: &unparametrized::UnparameterizedOperator, out_shapes: &[Shape], diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs index 430240655..ea22e068e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs @@ -1,11 +1,13 @@ use crate::parameters::{BrDecompositionParameters, KsDecompositionParameters}; -type Id = u64; +use crate::optimization::dag::multi_parameters::optimize::MacroParameters; + +pub type Id = u64; /* An Id is unique per key type. Starting from 0 for the first key ... */ -type SecretLweKeyId = Id; -type BootstrapKeyId = Id; -type KeySwitchKeyId = Id; -type ConversionKeySwitchKeyId = Id; +pub type SecretLweKeyId = Id; +pub type BootstrapKeyId = Id; +pub type KeySwitchKeyId = Id; +pub type ConversionKeySwitchKeyId = Id; #[derive(Debug, Clone)] pub struct SecretLweKey { @@ -38,7 +40,7 @@ pub struct KeySwitchKey { #[derive(Debug, Clone)] pub struct ConversionKeySwitchKey { - /* Public conversion to make cyphertext with incompatible keys compatible. + /* Public conversion to make compatible ciphertext with incompatible keys. It's currently only between two big secret keys. */ pub identifier: ConversionKeySwitchKeyId, pub input_key: SecretLweKey, @@ -48,7 +50,7 @@ pub struct ConversionKeySwitchKey { pub description: String, } -#[derive(Debug, Clone)] +#[derive(Debug, Default, Clone)] pub struct CircuitKeys { /* All keys used in a circuit, sorted by Id for each key type */ pub secret_keys: Vec, @@ -86,4 +88,129 @@ pub struct CircuitSolution { pub p_error: f64, /* result error rate, assuming any error will propagate to the result */ pub global_p_error: f64, + pub is_feasible: bool, + pub error_msg: String, +} + +pub struct ExpandedCircuitKeys { + pub big_secret_keys: Vec, + pub small_secret_keys: Vec, + pub keyswitch_keys: Vec>>, + pub bootstrap_keys: Vec, + pub conversion_keyswitch_keys: Vec>>, +} + +impl ExpandedCircuitKeys { + pub fn of(params: &super::optimize::Parameters) -> Self { + let nb_partitions = params.macro_params.len(); + let big_secret_keys: Vec<_> = params + .macro_params + .iter() + .enumerate() + .map(|(i, v): (usize, &Option)| { + let glwe_params = v.unwrap().glwe_params; + SecretLweKey { + identifier: i as Id, + polynomial_size: glwe_params.polynomial_size(), + glwe_dimension: glwe_params.glwe_dimension, + description: format!("big-secret[{i}]"), + } + }) + .collect(); + let small_secret_keys: Vec<_> = params + .macro_params + .iter() + .enumerate() + .map(|(i, v): (usize, &Option)| { + let polynomial_size = v.unwrap().internal_dim; + SecretLweKey { + identifier: (nb_partitions + i) as Id, + polynomial_size, + glwe_dimension: 1, + description: format!("small-secret[{i}]"), + } + }) + .collect(); + let bootstrap_keys: Vec<_> = params + .micro_params + .pbs + .iter() + .enumerate() + .map(|(i, v): (usize, &Option<_>)| { + let br_decomposition_parameter = v.unwrap().decomp; + BootstrapKey { + identifier: i as Id, + input_key: small_secret_keys[i].clone(), + output_key: big_secret_keys[i].clone(), + br_decomposition_parameter, + description: format!("pbs[{i}]"), + } + }) + .collect(); + let mut keyswitch_keys = vec![vec![None; nb_partitions]; nb_partitions]; + let mut conversion_keyswitch_keys = vec![vec![None; nb_partitions]; nb_partitions]; + let mut identifier_ks = 0 as Id; + let mut identifier_fks = 0 as Id; + #[allow(clippy::needless_range_loop)] + for src in 0..nb_partitions { + for dst in 0..nb_partitions { + let cross_key = |name: &str| { + if src == dst { + format!("{name}[{src}]") + } else { + format!("{name}[{src}->{dst}]") + } + }; + if let Some(ks) = params.micro_params.ks[src][dst] { + let identifier = identifier_ks; + keyswitch_keys[src][dst] = Some(KeySwitchKey { + identifier, + input_key: big_secret_keys[src].clone(), + output_key: small_secret_keys[dst].clone(), + ks_decomposition_parameter: ks.decomp, + description: cross_key("ks"), + }); + identifier_ks += 1; + } + if let Some(fks) = params.micro_params.fks[src][dst] { + let identifier = identifier_fks; + conversion_keyswitch_keys[src][dst] = Some(ConversionKeySwitchKey { + identifier, + input_key: big_secret_keys[src].clone(), + output_key: big_secret_keys[dst].clone(), + ks_decomposition_parameter: fks.decomp, + fast_keyswitch: true, + description: cross_key("fks"), + }); + identifier_fks += 1; + } + } + } + Self { + big_secret_keys, + small_secret_keys, + keyswitch_keys, + bootstrap_keys, + conversion_keyswitch_keys, + } + } + + pub fn compacted(self) -> CircuitKeys { + CircuitKeys { + secret_keys: [self.big_secret_keys, self.small_secret_keys].concat(), + keyswitch_keys: self + .keyswitch_keys + .into_iter() + .flatten() + .flatten() + .collect(), + bootstrap_keys: self.bootstrap_keys, + conversion_keyswitch_keys: self + .conversion_keyswitch_keys + .into_iter() + .flatten() + .flatten() + .collect(), + } + } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs index f7f14dbe6..459436c0d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs @@ -17,19 +17,20 @@ use crate::optimization::dag::multi_parameters::complexity::Complexity; use crate::optimization::dag::multi_parameters::feasible::Feasible; use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut; +use crate::optimization::dag::multi_parameters::{analyze, keys_spec}; const DEBUG: bool = false; #[derive(Debug, Clone)] pub struct MicroParameters { - pbs: Vec>, - ks: Vec>>, - fks: Vec>>, + pub pbs: Vec>, + pub ks: Vec>>, + pub fks: Vec>>, } // Parameters optimized for 1 partition: // the partition pbs, all used ks for all partitions, a much fks as partition -pub struct PartialMicroParameters { +struct PartialMicroParameters { pbs: CmuxComplexityNoise, ks: Vec>>, fks: Vec>>, @@ -40,19 +41,19 @@ pub struct PartialMicroParameters { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct MacroParameters { - glwe_params: GlweParameters, - internal_dim: u64, + pub glwe_params: GlweParameters, + pub internal_dim: u64, } #[derive(Debug, Clone)] pub struct Parameters { - micro_params: MicroParameters, - macro_params: Vec>, + pub micro_params: MicroParameters, + pub macro_params: Vec>, is_lower_bound: bool, is_feasible: bool, - p_error: f64, - global_p_error: f64, - complexity: f64, + pub p_error: f64, + pub global_p_error: f64, + pub complexity: f64, } #[derive(Debug, Clone)] @@ -792,7 +793,7 @@ pub fn optimize( persistent_caches: &PersistDecompCaches, p_cut: &Option, default_partition: PartitionIndex, -) -> Option { +) -> Option<(AnalyzedDag, Parameters)> { let ciphertext_modulus_log = config.ciphertext_modulus_log; let security_level = config.security_level; let noise_config = NoiseBoundConfig { @@ -904,7 +905,7 @@ pub fn optimize( &feasible, &complexity, ); - Some(params) + Some((dag, params)) } fn used_tlu_keyswitch(dag: &AnalyzedDag) -> Vec> { @@ -1003,5 +1004,48 @@ fn sanity_check( } } +pub fn optimize_to_circuit_solution( + dag: &unparametrized::OperationDag, + config: Config, + search_space: &SearchSpace, + persistent_caches: &PersistDecompCaches, + p_cut: &Option, +) -> keys_spec::CircuitSolution { + let default_partition = 0; + let dag_and_params = optimize( + dag, + config, + search_space, + persistent_caches, + p_cut, + default_partition, + ); + #[allow(clippy::option_if_let_else)] + if let Some((dag, params)) = dag_and_params { + let ext_keys = keys_spec::ExpandedCircuitKeys::of(¶ms); + let instructions_keys = analyze::original_instrs_partition(&dag, &ext_keys); + let circuit_keys = ext_keys.compacted(); + keys_spec::CircuitSolution { + circuit_keys, + instructions_keys, + complexity: params.complexity, + p_error: params.p_error, + global_p_error: params.global_p_error, + is_feasible: true, + error_msg: String::default(), + } + } else { + keys_spec::CircuitSolution { + circuit_keys: keys_spec::CircuitKeys::default(), + instructions_keys: vec![], + complexity: f64::INFINITY, + p_error: 1.0, + global_p_error: 1.0, + is_feasible: false, + error_msg: "No crypto-parameters for the given constraints".into(), + } + } +} + #[cfg(test)] include!("tests/test_optimize.rs"); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs index 31d63aae7..a2c2a5823 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs @@ -43,7 +43,7 @@ mod tests { &SHARED_CACHES, p_cut, default_partition, - ) + ).map(|v| v.1) } fn optimize_single(dag: &unparametrized::OperationDag) -> Option {