mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(optimizer): multiparameters compiler entry point
This commit is contained in:
@@ -6,7 +6,8 @@ use concrete_optimizer::dag::operator::{
|
||||
};
|
||||
use concrete_optimizer::dag::unparametrized;
|
||||
use concrete_optimizer::optimization::config::{Config, SearchSpace};
|
||||
use concrete_optimizer::optimization::dag::multi_parameters::keys_spec;
|
||||
use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::{self, CircuitSolution};
|
||||
use concrete_optimizer::optimization::dag::multi_parameters::optimize::optimize_to_circuit_solution;
|
||||
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{
|
||||
Encoding, Solution as DagSolution,
|
||||
};
|
||||
@@ -164,60 +165,81 @@ impl From<DagSolution> for ffi::DagSolution {
|
||||
}
|
||||
}
|
||||
|
||||
impl ffi::CircuitSolution {
|
||||
fn of(sol: ffi::DagSolution, dag: &OperationDag) -> Self {
|
||||
let big_key = ffi::SecretLweKey {
|
||||
identifier: 0,
|
||||
polynomial_size: sol.glwe_polynomial_size,
|
||||
glwe_dimension: sol.glwe_dimension,
|
||||
description: "big representation".into(),
|
||||
};
|
||||
let small_key = ffi::SecretLweKey {
|
||||
identifier: 1,
|
||||
polynomial_size: sol.internal_ks_output_lwe_dimension,
|
||||
glwe_dimension: 1,
|
||||
description: "small representation".into(),
|
||||
};
|
||||
let keyswitch_key = ffi::KeySwitchKey {
|
||||
identifier: 0,
|
||||
input_key: big_key.clone(),
|
||||
output_key: small_key.clone(),
|
||||
ks_decomposition_parameter: ffi::KsDecompositionParameters {
|
||||
level: sol.ks_decomposition_level_count,
|
||||
log2_base: sol.ks_decomposition_base_log,
|
||||
},
|
||||
description: "tlu keyswitch".into(),
|
||||
};
|
||||
let bootstrap_key = ffi::BootstrapKey {
|
||||
identifier: 0,
|
||||
input_key: small_key.clone(),
|
||||
output_key: big_key.clone(),
|
||||
br_decomposition_parameter: ffi::BrDecompositionParameters {
|
||||
level: sol.br_decomposition_level_count,
|
||||
log2_base: sol.br_decomposition_base_log,
|
||||
},
|
||||
description: "tlu bootsrap".into(),
|
||||
};
|
||||
let instruction_keys = ffi::InstructionKeys {
|
||||
input_key: big_key.identifier,
|
||||
tlu_keyswitch_key: keyswitch_key.identifier,
|
||||
tlu_bootstrap_key: bootstrap_key.identifier,
|
||||
output_key: big_key.identifier,
|
||||
extra_conversion_keys: vec![],
|
||||
};
|
||||
let instructions_keys = vec![instruction_keys; dag.0.len()];
|
||||
let circuit_keys = ffi::CircuitKeys {
|
||||
secret_keys: [big_key, small_key].into(),
|
||||
keyswitch_keys: [keyswitch_key].into(),
|
||||
bootstrap_keys: [bootstrap_key].into(),
|
||||
conversion_keyswitch_keys: [].into(),
|
||||
};
|
||||
ffi::CircuitSolution {
|
||||
circuit_keys,
|
||||
instructions_keys,
|
||||
complexity: sol.complexity,
|
||||
p_error: sol.p_error,
|
||||
global_p_error: sol.global_p_error,
|
||||
fn convert_to_circuit_solution(sol: &ffi::DagSolution, dag: &OperationDag) -> ffi::CircuitSolution {
|
||||
let big_key = ffi::SecretLweKey {
|
||||
identifier: 0,
|
||||
polynomial_size: sol.glwe_polynomial_size,
|
||||
glwe_dimension: sol.glwe_dimension,
|
||||
description: "big representation".into(),
|
||||
};
|
||||
let small_key = ffi::SecretLweKey {
|
||||
identifier: 1,
|
||||
polynomial_size: sol.internal_ks_output_lwe_dimension,
|
||||
glwe_dimension: 1,
|
||||
description: "small representation".into(),
|
||||
};
|
||||
let keyswitch_key = ffi::KeySwitchKey {
|
||||
identifier: 0,
|
||||
input_key: big_key.clone(),
|
||||
output_key: small_key.clone(),
|
||||
ks_decomposition_parameter: ffi::KsDecompositionParameters {
|
||||
level: sol.ks_decomposition_level_count,
|
||||
log2_base: sol.ks_decomposition_base_log,
|
||||
},
|
||||
description: "tlu keyswitch".into(),
|
||||
};
|
||||
let bootstrap_key = ffi::BootstrapKey {
|
||||
identifier: 0,
|
||||
input_key: small_key.clone(),
|
||||
output_key: big_key.clone(),
|
||||
br_decomposition_parameter: ffi::BrDecompositionParameters {
|
||||
level: sol.br_decomposition_level_count,
|
||||
log2_base: sol.br_decomposition_base_log,
|
||||
},
|
||||
description: "tlu bootstrap".into(),
|
||||
};
|
||||
let instruction_keys = ffi::InstructionKeys {
|
||||
input_key: big_key.identifier,
|
||||
tlu_keyswitch_key: keyswitch_key.identifier,
|
||||
tlu_bootstrap_key: bootstrap_key.identifier,
|
||||
output_key: big_key.identifier,
|
||||
extra_conversion_keys: vec![],
|
||||
};
|
||||
let instructions_keys = vec![instruction_keys; dag.0.len()];
|
||||
let circuit_keys = ffi::CircuitKeys {
|
||||
secret_keys: [big_key, small_key].into(),
|
||||
keyswitch_keys: [keyswitch_key].into(),
|
||||
bootstrap_keys: [bootstrap_key].into(),
|
||||
conversion_keyswitch_keys: [].into(),
|
||||
};
|
||||
let is_feasible = sol.p_error < 1.0;
|
||||
let error_msg = if is_feasible {
|
||||
""
|
||||
} else {
|
||||
"No crypto-parameters for the given constraints"
|
||||
}
|
||||
.into();
|
||||
ffi::CircuitSolution {
|
||||
circuit_keys,
|
||||
instructions_keys,
|
||||
complexity: sol.complexity,
|
||||
p_error: sol.p_error,
|
||||
global_p_error: sol.global_p_error,
|
||||
is_feasible,
|
||||
error_msg,
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CircuitSolution> for ffi::CircuitSolution {
|
||||
fn from(v: CircuitSolution) -> Self {
|
||||
Self {
|
||||
circuit_keys: v.circuit_keys.into(),
|
||||
instructions_keys: vec_into(v.instructions_keys),
|
||||
complexity: v.complexity,
|
||||
p_error: v.p_error,
|
||||
global_p_error: v.global_p_error,
|
||||
is_feasible: v.is_feasible,
|
||||
error_msg: v.error_msg,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -230,7 +252,7 @@ impl ffi::CircuitSolution {
|
||||
|
||||
impl From<KsDecompositionParameters> for ffi::KsDecompositionParameters {
|
||||
fn from(v: KsDecompositionParameters) -> Self {
|
||||
ffi::KsDecompositionParameters {
|
||||
Self {
|
||||
level: v.level,
|
||||
log2_base: v.log2_base,
|
||||
}
|
||||
@@ -239,7 +261,7 @@ impl From<KsDecompositionParameters> for ffi::KsDecompositionParameters {
|
||||
|
||||
impl From<BrDecompositionParameters> for ffi::BrDecompositionParameters {
|
||||
fn from(v: BrDecompositionParameters) -> Self {
|
||||
ffi::BrDecompositionParameters {
|
||||
Self {
|
||||
level: v.level,
|
||||
log2_base: v.log2_base,
|
||||
}
|
||||
@@ -294,6 +316,18 @@ impl From<keys_spec::BootstrapKey> for ffi::BootstrapKey {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<keys_spec::InstructionKeys> for ffi::InstructionKeys {
|
||||
fn from(v: keys_spec::InstructionKeys) -> Self {
|
||||
Self {
|
||||
input_key: v.input_key,
|
||||
tlu_keyswitch_key: v.tlu_keyswitch_key,
|
||||
tlu_bootstrap_key: v.tlu_bootstrap_key,
|
||||
output_key: v.output_key,
|
||||
extra_conversion_keys: v.extra_conversion_keys,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_into<F, T: std::convert::From<F>>(vec: Vec<F>) -> Vec<T> {
|
||||
vec.into_iter().map(|x| x.into()).collect()
|
||||
}
|
||||
@@ -432,8 +466,22 @@ impl OperationDag {
|
||||
}
|
||||
|
||||
fn optimize_multi(&self, options: ffi::Options) -> ffi::CircuitSolution {
|
||||
let single_parameter = self.optimize(options);
|
||||
ffi::CircuitSolution::of(single_parameter, self)
|
||||
let processing_unit = processing_unit(options);
|
||||
let config = Config {
|
||||
security_level: options.security_level,
|
||||
maximum_acceptable_error_probability: options.maximum_acceptable_error_probability,
|
||||
ciphertext_modulus_log: 64,
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
let search_space = SearchSpace::default(processing_unit);
|
||||
let circuit_sol = optimize_to_circuit_solution(
|
||||
&self.0,
|
||||
config,
|
||||
&search_space,
|
||||
&caches_from(options),
|
||||
&None,
|
||||
);
|
||||
circuit_sol.into()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,6 +528,12 @@ mod ffi {
|
||||
#[namespace = "concrete_optimizer::utils"]
|
||||
fn convert_to_dag_solution(solution: &Solution) -> DagSolution;
|
||||
|
||||
#[namespace = "concrete_optimizer::utils"]
|
||||
fn convert_to_circuit_solution(
|
||||
solution: &DagSolution,
|
||||
dag: &OperationDag,
|
||||
) -> CircuitSolution;
|
||||
|
||||
type OperationDag;
|
||||
|
||||
#[namespace = "concrete_optimizer::dag"]
|
||||
@@ -685,6 +739,8 @@ mod ffi {
|
||||
pub complexity: f64,
|
||||
pub p_error: f64,
|
||||
pub global_p_error: f64,
|
||||
pub is_feasible: bool,
|
||||
pub error_msg: String,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1188,6 +1188,8 @@ struct CircuitSolution final {
|
||||
double complexity;
|
||||
double p_error;
|
||||
double global_p_error;
|
||||
bool is_feasible;
|
||||
::rust::String error_msg;
|
||||
|
||||
::rust::String dump() const noexcept;
|
||||
using IsRelocatable = ::std::true_type;
|
||||
@@ -1204,6 +1206,8 @@ extern "C" {
|
||||
namespace utils {
|
||||
extern "C" {
|
||||
void concrete_optimizer$utils$cxxbridge1$convert_to_dag_solution(::concrete_optimizer::v0::Solution const &solution, ::concrete_optimizer::dag::DagSolution *return$) noexcept;
|
||||
|
||||
void concrete_optimizer$utils$cxxbridge1$convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept;
|
||||
} // extern "C"
|
||||
} // namespace utils
|
||||
|
||||
@@ -1269,6 +1273,12 @@ namespace utils {
|
||||
concrete_optimizer$utils$cxxbridge1$convert_to_dag_solution(solution, &return$.value);
|
||||
return ::std::move(return$.value);
|
||||
}
|
||||
|
||||
::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag) noexcept {
|
||||
::rust::MaybeUninit<::concrete_optimizer::dag::CircuitSolution> return$;
|
||||
concrete_optimizer$utils$cxxbridge1$convert_to_circuit_solution(solution, dag, &return$.value);
|
||||
return ::std::move(return$.value);
|
||||
}
|
||||
} // namespace utils
|
||||
|
||||
::std::size_t OperationDag::layout::size() noexcept {
|
||||
|
||||
@@ -1169,6 +1169,8 @@ struct CircuitSolution final {
|
||||
double complexity;
|
||||
double p_error;
|
||||
double global_p_error;
|
||||
bool is_feasible;
|
||||
::rust::String error_msg;
|
||||
|
||||
::rust::String dump() const noexcept;
|
||||
using IsRelocatable = ::std::true_type;
|
||||
@@ -1182,6 +1184,8 @@ namespace v0 {
|
||||
|
||||
namespace utils {
|
||||
::concrete_optimizer::dag::DagSolution convert_to_dag_solution(::concrete_optimizer::v0::Solution const &solution) noexcept;
|
||||
|
||||
::concrete_optimizer::dag::CircuitSolution convert_to_circuit_solution(::concrete_optimizer::dag::DagSolution const &solution, ::concrete_optimizer::OperationDag const &dag) noexcept;
|
||||
} // namespace utils
|
||||
|
||||
namespace dag {
|
||||
|
||||
@@ -109,7 +109,7 @@ void test_dag_lut_force_wop() {
|
||||
assert(!solution.crt_decomposition.empty());
|
||||
}
|
||||
|
||||
void test_multi_parameters() {
|
||||
void test_multi_parameters_1_precision() {
|
||||
auto dag = concrete_optimizer::dag::empty();
|
||||
|
||||
std::vector<uint64_t> shape = {3};
|
||||
@@ -122,8 +122,53 @@ void test_multi_parameters() {
|
||||
|
||||
auto options = default_options();
|
||||
auto circuit_solution = dag->optimize_multi(options);
|
||||
auto secret_keys = circuit_solution.circuit_keys.keyswitch_keys;
|
||||
assert(!secret_keys.empty());
|
||||
assert(circuit_solution.is_feasible);
|
||||
auto secret_keys = circuit_solution.circuit_keys.secret_keys;
|
||||
assert(circuit_solution.circuit_keys.secret_keys.size() == 2);
|
||||
assert(circuit_solution.circuit_keys.secret_keys[0].identifier == 0);
|
||||
assert(circuit_solution.circuit_keys.secret_keys[1].identifier == 1);
|
||||
assert(circuit_solution.circuit_keys.bootstrap_keys.size() == 1);
|
||||
assert(circuit_solution.circuit_keys.keyswitch_keys.size() == 1);
|
||||
assert(circuit_solution.circuit_keys.keyswitch_keys[0].identifier == 0);
|
||||
assert(circuit_solution.circuit_keys.keyswitch_keys[0].identifier == 0);
|
||||
assert(circuit_solution.circuit_keys.conversion_keyswitch_keys.size() == 0);
|
||||
}
|
||||
|
||||
void test_multi_parameters_2_precision() {
|
||||
auto dag = concrete_optimizer::dag::empty();
|
||||
|
||||
std::vector<uint64_t> shape = {3};
|
||||
|
||||
concrete_optimizer::dag::OperatorIndex input1 =
|
||||
dag->add_input(PRECISION_8B, slice(shape));
|
||||
|
||||
concrete_optimizer::dag::OperatorIndex input2 =
|
||||
dag->add_input(PRECISION_1B, slice(shape));
|
||||
|
||||
|
||||
std::vector<u_int64_t> table = {};
|
||||
auto lut1 = dag->add_lut(input1, slice(table), PRECISION_8B);
|
||||
auto lut2 = dag->add_lut(input2, slice(table), PRECISION_8B);
|
||||
|
||||
std::vector<concrete_optimizer::dag::OperatorIndex> inputs = {lut1, lut2};
|
||||
|
||||
std::vector<int64_t> weight_vec = {1, 1};
|
||||
|
||||
rust::cxxbridge1::Box<concrete_optimizer::Weights> weights =
|
||||
concrete_optimizer::weights::vector(slice(weight_vec));
|
||||
|
||||
dag->add_dot(slice(inputs), std::move(weights));
|
||||
|
||||
auto options = default_options();
|
||||
auto circuit_solution = dag->optimize_multi(options);
|
||||
assert(circuit_solution.is_feasible);
|
||||
auto secret_keys = circuit_solution.circuit_keys.secret_keys;
|
||||
assert(circuit_solution.circuit_keys.secret_keys.size() == 4);
|
||||
assert(circuit_solution.circuit_keys.bootstrap_keys.size() == 2);
|
||||
assert(circuit_solution.circuit_keys.keyswitch_keys.size() == 2); // 1 layer so less ks
|
||||
std::string actual = circuit_solution.circuit_keys.conversion_keyswitch_keys[0].description.c_str();
|
||||
std::string expected = "fks[1->0]";
|
||||
assert(actual == expected);
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -132,7 +177,8 @@ int main() {
|
||||
test_dag_lut();
|
||||
test_dag_lut_wop();
|
||||
test_dag_lut_force_wop();
|
||||
test_multi_parameters();
|
||||
test_multi_parameters_1_precision();
|
||||
test_multi_parameters_2_precision();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {
|
||||
pub(crate) fn regen(
|
||||
dag: &OperationDag,
|
||||
f: &mut dyn FnMut(usize, &Operator, &mut OperationDag) -> Option<OperatorIndex>,
|
||||
) -> OperationDag {
|
||||
) -> (OperationDag, Vec<Vec<OperatorIndex>>) {
|
||||
let mut regen_dag = OperationDag::new();
|
||||
let mut old_index_to_new = vec![];
|
||||
for (i, op) in dag.operators.iter().enumerate() {
|
||||
@@ -37,5 +37,25 @@ pub(crate) fn regen(
|
||||
regen_dag.out_shapes.push(dag.out_shapes[i].clone());
|
||||
}
|
||||
}
|
||||
regen_dag
|
||||
(regen_dag, instructions_multi_map(&old_index_to_new))
|
||||
}
|
||||
|
||||
fn instructions_multi_map(old_index_to_new: &[usize]) -> Vec<Vec<OperatorIndex>> {
|
||||
let mut last_new_instr = None;
|
||||
let mut result = vec![];
|
||||
result.reserve_exact(old_index_to_new.len());
|
||||
for &new_instr in old_index_to_new {
|
||||
let start_from = last_new_instr.map_or(new_instr, |v: usize| v + 1);
|
||||
if start_from <= new_instr {
|
||||
result.push(
|
||||
(start_from..=new_instr)
|
||||
.map(|i| OperatorIndex { i })
|
||||
.collect(),
|
||||
);
|
||||
} else {
|
||||
result.push(vec![]);
|
||||
}
|
||||
last_new_instr = Some(new_instr);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
@@ -14,5 +14,11 @@ fn regen_round(_: usize, op: &Operator, dag: &mut OperationDag) -> Option<Operat
|
||||
}
|
||||
|
||||
pub(crate) fn expand_round(dag: &OperationDag) -> OperationDag {
|
||||
regen(dag, &mut regen_round).0
|
||||
}
|
||||
|
||||
pub(crate) fn expand_round_and_index_map(
|
||||
dag: &OperationDag,
|
||||
) -> (OperationDag, Vec<Vec<OperatorIndex>>) {
|
||||
regen(dag, &mut regen_round)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::collections::HashSet;
|
||||
use crate::dag::operator::{
|
||||
dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape,
|
||||
};
|
||||
use crate::dag::rewrite::round::expand_round;
|
||||
use crate::dag::rewrite::round::expand_round_and_index_map;
|
||||
use crate::dag::unparametrized;
|
||||
use crate::optimization::config::NoiseBoundConfig;
|
||||
use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred;
|
||||
@@ -17,6 +17,7 @@ use crate::optimization::dag::solo_key::analyze::{
|
||||
};
|
||||
|
||||
use super::complexity::OperationsCount;
|
||||
use super::keys_spec;
|
||||
use super::operations_value::OperationsValue;
|
||||
use super::variance_constraint::VarianceConstraint;
|
||||
|
||||
@@ -41,6 +42,7 @@ pub struct AnalyzedDag {
|
||||
pub undominated_variance_constraints: Vec<VarianceConstraint>,
|
||||
pub operations_count_per_instrs: Vec<OperationsCount>,
|
||||
pub operations_count: OperationsCount,
|
||||
pub instruction_rewrite_index: Vec<Vec<OperatorIndex>>,
|
||||
pub p_cut: PrecisionCut,
|
||||
}
|
||||
|
||||
@@ -50,7 +52,7 @@ pub fn analyze(
|
||||
p_cut: &Option<PrecisionCut>,
|
||||
default_partition: PartitionIndex,
|
||||
) -> AnalyzedDag {
|
||||
let dag = expand_round(dag);
|
||||
let (dag, instruction_rewrite_index) = expand_round_and_index_map(dag);
|
||||
let levelled_complexity = LevelledComplexity::ZERO;
|
||||
// The precision cut is chosen to work well with rounded pbs
|
||||
// Note: this is temporary
|
||||
@@ -72,6 +74,7 @@ pub fn analyze(
|
||||
let operations_count = sum_operations_count(&operations_count_per_instrs);
|
||||
AnalyzedDag {
|
||||
operators: dag.operators,
|
||||
instruction_rewrite_index,
|
||||
nb_partitions,
|
||||
instrs_partition,
|
||||
out_variances,
|
||||
@@ -84,6 +87,77 @@ pub fn analyze(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn original_instrs_partition(
|
||||
dag: &AnalyzedDag,
|
||||
keys: &keys_spec::ExpandedCircuitKeys,
|
||||
) -> Vec<keys_spec::InstructionKeys> {
|
||||
let big_keys = &keys.big_secret_keys;
|
||||
let ks_keys = &keys.keyswitch_keys;
|
||||
let pbs_keys = &keys.bootstrap_keys;
|
||||
let fks_keys = &keys.conversion_keyswitch_keys;
|
||||
let mut result = vec![];
|
||||
result.reserve_exact(dag.instruction_rewrite_index.len());
|
||||
let unknown = keys_spec::Id::MAX;
|
||||
for new_instructions in &dag.instruction_rewrite_index {
|
||||
let mut partition = None;
|
||||
let mut input_partition = None;
|
||||
let mut tlu_keyswitch_key = None;
|
||||
let mut tlu_bootstrap_key = None;
|
||||
let mut conversion_key = None;
|
||||
// let mut extra_conversion_keys = None;
|
||||
for (i, new_instruction) in new_instructions.iter().enumerate() {
|
||||
// focus on TLU information
|
||||
let new_instr_part = &dag.instrs_partition[new_instruction.i];
|
||||
if let Op::Lut { .. } = dag.operators[new_instruction.i] {
|
||||
let ks_dst = new_instr_part.instruction_partition;
|
||||
partition = Some(ks_dst);
|
||||
#[allow(clippy::match_on_vec_items)]
|
||||
let ks_src = match new_instr_part.inputs_transition[0] {
|
||||
Some(Transition::Internal { src_partition }) => src_partition,
|
||||
None => ks_dst,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
input_partition = Some(ks_src);
|
||||
let ks_key = ks_keys[ks_src][ks_dst].as_ref().unwrap().identifier;
|
||||
let pbs_key = pbs_keys[ks_dst].identifier;
|
||||
assert!(tlu_keyswitch_key.unwrap_or(ks_key) == ks_key);
|
||||
assert!(tlu_bootstrap_key.unwrap_or(pbs_key) == pbs_key);
|
||||
tlu_keyswitch_key = Some(ks_key);
|
||||
tlu_bootstrap_key = Some(pbs_key);
|
||||
}
|
||||
if !new_instr_part.alternative_output_representation.is_empty() {
|
||||
assert!(new_instr_part.alternative_output_representation.len() == 1);
|
||||
let src = new_instr_part.instruction_partition;
|
||||
let dst = *new_instr_part
|
||||
.alternative_output_representation
|
||||
.iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
let key = fks_keys[src][dst].as_ref().unwrap().identifier;
|
||||
assert!(conversion_key.unwrap_or(key) == key);
|
||||
conversion_key = Some(key);
|
||||
}
|
||||
// Only last instruction can have alternative conversion
|
||||
assert!(
|
||||
new_instr_part.alternative_output_representation.is_empty()
|
||||
|| i == new_instructions.len() - 1
|
||||
);
|
||||
}
|
||||
let partition =
|
||||
partition.unwrap_or(dag.instrs_partition[new_instructions[0].i].instruction_partition);
|
||||
let input_partition = input_partition.unwrap_or(partition);
|
||||
let merged = keys_spec::InstructionKeys {
|
||||
input_key: big_keys[input_partition].identifier,
|
||||
tlu_keyswitch_key: tlu_keyswitch_key.unwrap_or(unknown),
|
||||
tlu_bootstrap_key: tlu_bootstrap_key.unwrap_or(unknown),
|
||||
output_key: big_keys[partition].identifier,
|
||||
extra_conversion_keys: conversion_key.iter().copied().collect(),
|
||||
};
|
||||
result.push(merged);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn out_variance(
|
||||
op: &unparametrized::UnparameterizedOperator,
|
||||
out_shapes: &[Shape],
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use crate::parameters::{BrDecompositionParameters, KsDecompositionParameters};
|
||||
|
||||
type Id = u64;
|
||||
use crate::optimization::dag::multi_parameters::optimize::MacroParameters;
|
||||
|
||||
pub type Id = u64;
|
||||
/* An Id is unique per key type. Starting from 0 for the first key ... */
|
||||
type SecretLweKeyId = Id;
|
||||
type BootstrapKeyId = Id;
|
||||
type KeySwitchKeyId = Id;
|
||||
type ConversionKeySwitchKeyId = Id;
|
||||
pub type SecretLweKeyId = Id;
|
||||
pub type BootstrapKeyId = Id;
|
||||
pub type KeySwitchKeyId = Id;
|
||||
pub type ConversionKeySwitchKeyId = Id;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SecretLweKey {
|
||||
@@ -38,7 +40,7 @@ pub struct KeySwitchKey {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConversionKeySwitchKey {
|
||||
/* Public conversion to make cyphertext with incompatible keys compatible.
|
||||
/* Public conversion to make compatible ciphertext with incompatible keys.
|
||||
It's currently only between two big secret keys. */
|
||||
pub identifier: ConversionKeySwitchKeyId,
|
||||
pub input_key: SecretLweKey,
|
||||
@@ -48,7 +50,7 @@ pub struct ConversionKeySwitchKey {
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct CircuitKeys {
|
||||
/* All keys used in a circuit, sorted by Id for each key type */
|
||||
pub secret_keys: Vec<SecretLweKey>,
|
||||
@@ -86,4 +88,129 @@ pub struct CircuitSolution {
|
||||
pub p_error: f64,
|
||||
/* result error rate, assuming any error will propagate to the result */
|
||||
pub global_p_error: f64,
|
||||
pub is_feasible: bool,
|
||||
pub error_msg: String,
|
||||
}
|
||||
|
||||
pub struct ExpandedCircuitKeys {
|
||||
pub big_secret_keys: Vec<SecretLweKey>,
|
||||
pub small_secret_keys: Vec<SecretLweKey>,
|
||||
pub keyswitch_keys: Vec<Vec<Option<KeySwitchKey>>>,
|
||||
pub bootstrap_keys: Vec<BootstrapKey>,
|
||||
pub conversion_keyswitch_keys: Vec<Vec<Option<ConversionKeySwitchKey>>>,
|
||||
}
|
||||
|
||||
impl ExpandedCircuitKeys {
|
||||
pub fn of(params: &super::optimize::Parameters) -> Self {
|
||||
let nb_partitions = params.macro_params.len();
|
||||
let big_secret_keys: Vec<_> = params
|
||||
.macro_params
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v): (usize, &Option<MacroParameters>)| {
|
||||
let glwe_params = v.unwrap().glwe_params;
|
||||
SecretLweKey {
|
||||
identifier: i as Id,
|
||||
polynomial_size: glwe_params.polynomial_size(),
|
||||
glwe_dimension: glwe_params.glwe_dimension,
|
||||
description: format!("big-secret[{i}]"),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let small_secret_keys: Vec<_> = params
|
||||
.macro_params
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v): (usize, &Option<MacroParameters>)| {
|
||||
let polynomial_size = v.unwrap().internal_dim;
|
||||
SecretLweKey {
|
||||
identifier: (nb_partitions + i) as Id,
|
||||
polynomial_size,
|
||||
glwe_dimension: 1,
|
||||
description: format!("small-secret[{i}]"),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let bootstrap_keys: Vec<_> = params
|
||||
.micro_params
|
||||
.pbs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v): (usize, &Option<_>)| {
|
||||
let br_decomposition_parameter = v.unwrap().decomp;
|
||||
BootstrapKey {
|
||||
identifier: i as Id,
|
||||
input_key: small_secret_keys[i].clone(),
|
||||
output_key: big_secret_keys[i].clone(),
|
||||
br_decomposition_parameter,
|
||||
description: format!("pbs[{i}]"),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let mut keyswitch_keys = vec![vec![None; nb_partitions]; nb_partitions];
|
||||
let mut conversion_keyswitch_keys = vec![vec![None; nb_partitions]; nb_partitions];
|
||||
let mut identifier_ks = 0 as Id;
|
||||
let mut identifier_fks = 0 as Id;
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for src in 0..nb_partitions {
|
||||
for dst in 0..nb_partitions {
|
||||
let cross_key = |name: &str| {
|
||||
if src == dst {
|
||||
format!("{name}[{src}]")
|
||||
} else {
|
||||
format!("{name}[{src}->{dst}]")
|
||||
}
|
||||
};
|
||||
if let Some(ks) = params.micro_params.ks[src][dst] {
|
||||
let identifier = identifier_ks;
|
||||
keyswitch_keys[src][dst] = Some(KeySwitchKey {
|
||||
identifier,
|
||||
input_key: big_secret_keys[src].clone(),
|
||||
output_key: small_secret_keys[dst].clone(),
|
||||
ks_decomposition_parameter: ks.decomp,
|
||||
description: cross_key("ks"),
|
||||
});
|
||||
identifier_ks += 1;
|
||||
}
|
||||
if let Some(fks) = params.micro_params.fks[src][dst] {
|
||||
let identifier = identifier_fks;
|
||||
conversion_keyswitch_keys[src][dst] = Some(ConversionKeySwitchKey {
|
||||
identifier,
|
||||
input_key: big_secret_keys[src].clone(),
|
||||
output_key: big_secret_keys[dst].clone(),
|
||||
ks_decomposition_parameter: fks.decomp,
|
||||
fast_keyswitch: true,
|
||||
description: cross_key("fks"),
|
||||
});
|
||||
identifier_fks += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
Self {
|
||||
big_secret_keys,
|
||||
small_secret_keys,
|
||||
keyswitch_keys,
|
||||
bootstrap_keys,
|
||||
conversion_keyswitch_keys,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compacted(self) -> CircuitKeys {
|
||||
CircuitKeys {
|
||||
secret_keys: [self.big_secret_keys, self.small_secret_keys].concat(),
|
||||
keyswitch_keys: self
|
||||
.keyswitch_keys
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.flatten()
|
||||
.collect(),
|
||||
bootstrap_keys: self.bootstrap_keys,
|
||||
conversion_keyswitch_keys: self
|
||||
.conversion_keyswitch_keys
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.flatten()
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,19 +17,20 @@ use crate::optimization::dag::multi_parameters::complexity::Complexity;
|
||||
use crate::optimization::dag::multi_parameters::feasible::Feasible;
|
||||
use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
|
||||
use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut;
|
||||
use crate::optimization::dag::multi_parameters::{analyze, keys_spec};
|
||||
|
||||
const DEBUG: bool = false;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MicroParameters {
|
||||
pbs: Vec<Option<CmuxComplexityNoise>>,
|
||||
ks: Vec<Vec<Option<KsComplexityNoise>>>,
|
||||
fks: Vec<Vec<Option<FksComplexityNoise>>>,
|
||||
pub pbs: Vec<Option<CmuxComplexityNoise>>,
|
||||
pub ks: Vec<Vec<Option<KsComplexityNoise>>>,
|
||||
pub fks: Vec<Vec<Option<FksComplexityNoise>>>,
|
||||
}
|
||||
|
||||
// Parameters optimized for 1 partition:
|
||||
// the partition pbs, all used ks for all partitions, a much fks as partition
|
||||
pub struct PartialMicroParameters {
|
||||
struct PartialMicroParameters {
|
||||
pbs: CmuxComplexityNoise,
|
||||
ks: Vec<Vec<Option<KsComplexityNoise>>>,
|
||||
fks: Vec<Vec<Option<FksComplexityNoise>>>,
|
||||
@@ -40,19 +41,19 @@ pub struct PartialMicroParameters {
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct MacroParameters {
|
||||
glwe_params: GlweParameters,
|
||||
internal_dim: u64,
|
||||
pub glwe_params: GlweParameters,
|
||||
pub internal_dim: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Parameters {
|
||||
micro_params: MicroParameters,
|
||||
macro_params: Vec<Option<MacroParameters>>,
|
||||
pub micro_params: MicroParameters,
|
||||
pub macro_params: Vec<Option<MacroParameters>>,
|
||||
is_lower_bound: bool,
|
||||
is_feasible: bool,
|
||||
p_error: f64,
|
||||
global_p_error: f64,
|
||||
complexity: f64,
|
||||
pub p_error: f64,
|
||||
pub global_p_error: f64,
|
||||
pub complexity: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -792,7 +793,7 @@ pub fn optimize(
|
||||
persistent_caches: &PersistDecompCaches,
|
||||
p_cut: &Option<PrecisionCut>,
|
||||
default_partition: PartitionIndex,
|
||||
) -> Option<Parameters> {
|
||||
) -> Option<(AnalyzedDag, Parameters)> {
|
||||
let ciphertext_modulus_log = config.ciphertext_modulus_log;
|
||||
let security_level = config.security_level;
|
||||
let noise_config = NoiseBoundConfig {
|
||||
@@ -904,7 +905,7 @@ pub fn optimize(
|
||||
&feasible,
|
||||
&complexity,
|
||||
);
|
||||
Some(params)
|
||||
Some((dag, params))
|
||||
}
|
||||
|
||||
fn used_tlu_keyswitch(dag: &AnalyzedDag) -> Vec<Vec<bool>> {
|
||||
@@ -1003,5 +1004,48 @@ fn sanity_check(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize_to_circuit_solution(
|
||||
dag: &unparametrized::OperationDag,
|
||||
config: Config,
|
||||
search_space: &SearchSpace,
|
||||
persistent_caches: &PersistDecompCaches,
|
||||
p_cut: &Option<PrecisionCut>,
|
||||
) -> keys_spec::CircuitSolution {
|
||||
let default_partition = 0;
|
||||
let dag_and_params = optimize(
|
||||
dag,
|
||||
config,
|
||||
search_space,
|
||||
persistent_caches,
|
||||
p_cut,
|
||||
default_partition,
|
||||
);
|
||||
#[allow(clippy::option_if_let_else)]
|
||||
if let Some((dag, params)) = dag_and_params {
|
||||
let ext_keys = keys_spec::ExpandedCircuitKeys::of(¶ms);
|
||||
let instructions_keys = analyze::original_instrs_partition(&dag, &ext_keys);
|
||||
let circuit_keys = ext_keys.compacted();
|
||||
keys_spec::CircuitSolution {
|
||||
circuit_keys,
|
||||
instructions_keys,
|
||||
complexity: params.complexity,
|
||||
p_error: params.p_error,
|
||||
global_p_error: params.global_p_error,
|
||||
is_feasible: true,
|
||||
error_msg: String::default(),
|
||||
}
|
||||
} else {
|
||||
keys_spec::CircuitSolution {
|
||||
circuit_keys: keys_spec::CircuitKeys::default(),
|
||||
instructions_keys: vec![],
|
||||
complexity: f64::INFINITY,
|
||||
p_error: 1.0,
|
||||
global_p_error: 1.0,
|
||||
is_feasible: false,
|
||||
error_msg: "No crypto-parameters for the given constraints".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
include!("tests/test_optimize.rs");
|
||||
|
||||
@@ -43,7 +43,7 @@ mod tests {
|
||||
&SHARED_CACHES,
|
||||
p_cut,
|
||||
default_partition,
|
||||
)
|
||||
).map(|v| v.1)
|
||||
}
|
||||
|
||||
fn optimize_single(dag: &unparametrized::OperationDag) -> Option<Parameters> {
|
||||
|
||||
Reference in New Issue
Block a user