diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 0fcb5e7fc..ae2dd3f2b 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -129,18 +129,20 @@ struct FunctionToDag { void addOperation(optimizer::Dag &dag, mlir::Operation &op) { DEBUG("Instr " << op); + auto encrypted_inputs = encryptedInputs(op); if (isReturn(op)) { - // This op has no result + for (auto op : encrypted_inputs) { + dag->tag_operator_as_output(op); + } return; } - - auto encrypted_inputs = encryptedInputs(op); if (!hasEncryptedResult(op)) { // This op is unrelated to FHE assert(encrypted_inputs.empty() || mlir::isa(op)); return; } + assert(op.getNumResults() == 1); auto val = op.getResult(0); auto precision = fhe::utils::getEintPrecision(val); diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc index 1e0f61705..0dcb52b7a 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc @@ -410,9 +410,7 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { err, "Program can not be composed: No luts in the circuit."); } -// This test pass while the compilation should failed as %1 is not refresh so it -// should not be composable. -TEST(DISABLED_CompileNotComposable, not_composable_2) { +TEST(CompileNotComposable, not_composable_2) { mlir::concretelang::CompilationOptions options("main"); options.optimizerConfig.composable = true; options.optimizerConfig.display = true; @@ -427,7 +425,8 @@ func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) { return %1, %2: !FHE.eint<3>, !FHE.eint<3> } )XXX"); - ASSERT_OUTCOME_HAS_FAILURE_WITH_ERRORMSG(err, "NotComposable"); + ASSERT_OUTCOME_HAS_FAILURE_WITH_ERRORMSG( + err, "Program can not be composed: Output 1 has variance 1σ²In[0]."); } TEST(CompileComposable, composable_supported_dag_mono) { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index 02fbee070..b02ca15e9 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -599,6 +599,10 @@ impl OperationDag { self.0.dump() } + fn tag_operator_as_output(&mut self, op: ffi::OperatorIndex) { + self.0.tag_operator_as_output(op.into()); + } + fn optimize_multi(&self, options: ffi::Options) -> ffi::CircuitSolution { let processing_unit = processing_unit(options); let config = Config { @@ -751,6 +755,8 @@ mod ffi { #[namespace = "concrete_optimizer::weights"] fn number(weight: i64) -> Box; + fn tag_operator_as_output(self: &mut OperationDag, op: OperatorIndex); + fn optimize_multi(self: &OperationDag, options: Options) -> CircuitSolution; fn NO_KEY_ID() -> u64; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index a445ab635..5733ae123 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -976,6 +976,7 @@ struct OperationDag final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; ::rust::String dump() const noexcept; + void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; ~OperationDag() = delete; @@ -1312,6 +1313,8 @@ extern "C" { } // namespace weights extern "C" { +void concrete_optimizer$cxxbridge1$OperationDag$tag_operator_as_output(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex op) noexcept; + void concrete_optimizer$cxxbridge1$OperationDag$optimize_multi(::concrete_optimizer::OperationDag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept; ::std::uint64_t concrete_optimizer$cxxbridge1$NO_KEY_ID() noexcept; @@ -1419,6 +1422,10 @@ namespace weights { } } // namespace weights +void OperationDag::tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept { + concrete_optimizer$cxxbridge1$OperationDag$tag_operator_as_output(*this, op); +} + ::concrete_optimizer::dag::CircuitSolution OperationDag::optimize_multi(::concrete_optimizer::Options options) const noexcept { ::rust::MaybeUninit<::concrete_optimizer::dag::CircuitSolution> return$; concrete_optimizer$cxxbridge1$OperationDag$optimize_multi(*this, options, &return$.value); diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 2c3d3fd82..486d9adce 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -957,6 +957,7 @@ struct OperationDag final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; ::rust::String dump() const noexcept; + void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; ~OperationDag() = delete; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp index d23b21e1b..1441e09eb 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp @@ -61,7 +61,8 @@ TEST test_dag_no_lut() { rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); - dag->add_dot(slice(inputs), std::move(weights)); + auto id = dag->add_dot(slice(inputs), std::move(weights)); + dag->tag_operator_as_output(id); auto solution = dag->optimize(default_options()); assert(solution.glwe_polynomial_size == 1); @@ -77,7 +78,8 @@ TEST test_dag_lut() { dag->add_input(PRECISION_8B, slice(shape)); std::vector table = {}; - dag->add_lut(input, slice(table), PRECISION_8B); + auto id = dag->add_lut(input, slice(table), PRECISION_8B); + dag->tag_operator_as_output(id); auto solution = dag->optimize(default_options()); assert(solution.glwe_dimension == 1); @@ -94,7 +96,8 @@ TEST test_dag_lut_wop() { dag->add_input(PRECISION_16B, slice(shape)); std::vector table = {}; - dag->add_lut(input, slice(table), PRECISION_16B); + auto id = dag->add_lut(input, slice(table), PRECISION_16B); + dag->tag_operator_as_output(id); auto solution = dag->optimize(default_options()); assert(solution.glwe_dimension == 2); @@ -111,7 +114,8 @@ TEST test_dag_lut_force_wop() { dag->add_input(PRECISION_8B, slice(shape)); std::vector table = {}; - dag->add_lut(input, slice(table), PRECISION_8B); + auto id = dag->add_lut(input, slice(table), PRECISION_8B); + dag->tag_operator_as_output(id); auto options = default_options(); options.encoding = concrete_optimizer::Encoding::Crt; @@ -129,7 +133,8 @@ TEST test_multi_parameters_1_precision() { dag->add_input(PRECISION_8B, slice(shape)); std::vector table = {}; - dag->add_lut(input, slice(table), PRECISION_8B); + auto id = dag->add_lut(input, slice(table), PRECISION_8B); + dag->tag_operator_as_output(id); auto options = default_options(); auto circuit_solution = dag->optimize_multi(options); @@ -167,7 +172,8 @@ TEST test_multi_parameters_2_precision() { rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); - dag->add_dot(slice(inputs), std::move(weights)); + auto id = dag->add_dot(slice(inputs), std::move(weights)); + dag->tag_operator_as_output(id); auto options = default_options(); auto circuit_solution = dag->optimize_multi(options); @@ -206,7 +212,8 @@ TEST test_multi_parameters_2_precision_crt() { rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); - dag->add_dot(slice(inputs), std::move(weights)); + auto id = dag->add_dot(slice(inputs), std::move(weights)); + dag->tag_operator_as_output(id); auto options = default_options(); options.encoding = concrete_optimizer::Encoding::Crt; @@ -238,7 +245,8 @@ TEST test_composable_dag_mono_fallback_on_dag_multi() { std::vector lut1v = {lut1}; rust::cxxbridge1::Box weights2 = concrete_optimizer::weights::vector(slice(weight_vec)); - dag->add_dot(slice(lut1v), std::move(weights2)); + auto id = dag->add_dot(slice(lut1v), std::move(weights2)); + dag->tag_operator_as_output(id); auto options = default_options(); auto solution1 = dag->optimize(options); @@ -272,7 +280,8 @@ TEST test_non_composable_dag_mono_fallback_on_woppbs() { std::vector lut1v = {lut1}; rust::cxxbridge1::Box weights2 = concrete_optimizer::weights::vector(slice(weight_vec)); - dag->add_dot(slice(lut1v), std::move(weights2)); + auto id = dag->add_dot(slice(lut1v), std::move(weights2)); + dag->tag_operator_as_output(id); auto options = default_options(); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs index cdb8852a4..1154db9ba 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::iter::{empty, once}; use crate::dag::operator::tensor::{ClearTensor, Shape}; @@ -101,6 +102,19 @@ pub enum Operator { }, } +impl Operator { + // Returns an iterator on the indices of the operator inputs. + pub(crate) fn get_inputs_iter(&self) -> Box + '_> { + match self { + Self::Input { .. } => Box::new(empty()), + Self::LevelledOp { inputs, .. } | Self::Dot { inputs, .. } => Box::new(inputs.iter()), + Self::UnsafeCast { input, .. } + | Self::Lut { input, .. } + | Self::Round { input, .. } => Box::new(once(input)), + } + } +} + #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub struct OperatorIndex { pub i: usize, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs index d4a295216..b9ea66552 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs @@ -35,6 +35,7 @@ pub(crate) fn regen( regen_dag.operators.push(op.clone()); regen_dag.out_precisions.push(dag.out_precisions[i]); regen_dag.out_shapes.push(dag.out_shapes[i].clone()); + regen_dag.output_tags.push(dag.output_tags[i]); } } (regen_dag, instructions_multi_map(&old_index_to_new)) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index 7535a3075..a7a289401 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -5,7 +5,6 @@ use crate::dag::operator::{ dot_kind, DotKind, FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights, }; -use crate::optimization::dag::solo_key::analyze::extra_final_values_to_check; pub(crate) type UnparameterizedOperator = Operator; @@ -17,6 +16,8 @@ pub struct OperationDag { pub(crate) out_shapes: Vec, // Collect all operators ouput precision pub(crate) out_precisions: Vec, + // Collect whether operators are tagged as outputs + pub(crate) output_tags: Vec, } impl fmt::Display for OperationDag { @@ -34,6 +35,7 @@ impl OperationDag { operators: vec![], out_shapes: vec![], out_precisions: vec![], + output_tags: vec![], } } @@ -43,6 +45,7 @@ impl OperationDag { .push(self.infer_out_precision(&operator)); self.out_shapes.push(self.infer_out_shape(&operator)); self.operators.push(operator); + self.output_tags.push(false); OperatorIndex { i } } @@ -130,6 +133,11 @@ impl OperationDag { }) } + pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) { + assert!(operator.i < self.len()); + self.output_tags[operator.i] = true; + } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { self.operators.len() @@ -244,6 +252,7 @@ impl OperationDag { self.add_lut(rounded, table, out_precision) } + /// Returns an iterator over input nodes indices. pub(crate) fn get_input_index_iter(&self) -> impl Iterator + '_ { self.operators .iter() @@ -254,12 +263,36 @@ impl OperationDag { }) } - pub(crate) fn get_output_index(&self) -> Vec { - return extra_final_values_to_check(self) + /// If no outputs were declared, automatically tag final nodes as outputs. + #[allow(unused)] + pub(crate) fn detect_outputs(&mut self) { + assert!(!self.is_output_tagged()); + self.output_tags = vec![true; self.len()]; + self.operators + .iter() + .flat_map(|op| op.get_inputs_iter()) + .for_each(|op| self.output_tags[op.i] = false); + } + + fn is_output_tagged(&self) -> bool { + self.output_tags + .iter() + .copied() + .reduce(|a, b| a || b) + .unwrap() + } + + /// Returns an iterator over output nodes indices. + pub(crate) fn get_output_index_iter(&self) -> impl Iterator + '_ { + self.output_tags .iter() .enumerate() .filter_map(|(index, is_output)| is_output.then_some(index)) - .collect(); + } + + /// Returns whether the node is tagged as output. + pub(crate) fn is_output_node(&self, oid: usize) -> bool { + self.output_tags[oid] } fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index db141d732..1330f4c0c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -10,9 +10,7 @@ use crate::optimization::dag::multi_parameters::partitions::{ InstructionPartition, PartitionIndex, Transition, }; use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance; -use crate::optimization::dag::solo_key::analyze::{ - extra_final_values_to_check, first, safe_noise_bound, -}; +use crate::optimization::dag::solo_key::analyze::{first, safe_noise_bound}; use crate::optimization::Err::NotComposable; use crate::optimization::Result; @@ -72,8 +70,7 @@ pub fn analyze( check_composability(&dag, &out_variances, nb_partitions)?; // Get the largest output out_variance let largest_output_variances = dag - .get_output_index() - .into_iter() + .get_output_index_iter() .map(|index| out_variances[index].clone()) .reduce(|lhs, rhs| { lhs.into_iter() @@ -129,8 +126,7 @@ fn check_composability( // If the circuit outputs are free from input variances, it means that every outputs are // refreshed, and the function can be composed. let in_var_in_out_var = dag - .get_output_index() - .into_iter() + .get_output_index_iter() .flat_map(|index| symbolic_variances[index].iter().map(move |v| (index, v))) .find_map(|(output_index, sym_var)| { (0..nb_partitions) @@ -345,7 +341,6 @@ fn collect_all_variance_constraints( instrs_partition: &[InstructionPartition], out_variances: &[Vec], ) -> Vec { - let decryption_points = extra_final_values_to_check(dag); let mut constraints = vec![]; for (op_i, op) in dag.operators.iter().enumerate() { let partition = instrs_partition[op_i].instruction_partition; @@ -381,7 +376,7 @@ fn collect_all_variance_constraints( variance, )); } - if decryption_points[op_i] { + if dag.is_output_node(op_i) { let precision = dag.out_precisions[op_i]; let variance = out_variances[op_i][partition].clone(); constraints.push(variance_constraint( @@ -532,6 +527,7 @@ pub mod tests { let mut dag = unparametrized::OperationDag::new(); let _ = dag.add_input(1, Shape::number()); let p_cut = PartitionCut::for_each_precision(&dag); + dag.detect_outputs(); let res = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true); assert!(res.is_ok()); } @@ -542,6 +538,7 @@ pub mod tests { let input1 = dag.add_input(1, Shape::number()); let _ = dag.add_lut(input1, FunctionTable::UNKWOWN, 2); let p_cut = PartitionCut::for_each_precision(&dag); + dag.detect_outputs(); let dag = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true).unwrap(); assert!(dag.nb_partitions == 1); @@ -565,6 +562,7 @@ pub mod tests { let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let input2 = dag.add_dot([input1, lut3], [1, 1]); let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); + dag.detect_outputs(); let analyzed_dag = super::analyze(&dag, &CONFIG, &None, LOW_PRECISION_PARTITION, true).unwrap(); assert_eq!(analyzed_dag.nb_partitions, 2); @@ -608,6 +606,7 @@ pub mod tests { let b = dag.add_dot([input1, lut3], [1, 1]); let _ = dag.add_lut(a, FunctionTable::UNKWOWN, 3); let _ = dag.add_lut(b, FunctionTable::UNKWOWN, 3); + dag.detect_outputs(); let analyzed_dag = super::analyze(&dag, &CONFIG, &None, 1, true).unwrap(); assert_eq!(analyzed_dag.nb_partitions, 3); let actual_constraint_strings = analyzed_dag @@ -659,6 +658,7 @@ pub mod tests { LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION, ]; + dag.detect_outputs(); let dag = analyze(&dag); assert!(dag.nb_partitions == 2); for op_i in input1.i..=lut5.i { @@ -690,6 +690,7 @@ pub mod tests { &out_shape, "comment", ); + dag.detect_outputs(); let dag = analyze(&dag); assert!(dag.nb_partitions == 1); } @@ -709,6 +710,7 @@ pub mod tests { let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let rounded2 = dag.add_expanded_round(lut1, precision); let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision); + dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); show_partitionning(&old_dag, &dag.instrs_partition); @@ -776,6 +778,7 @@ pub mod tests { let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); + dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); show_partitionning(&old_dag, &dag.instrs_partition); @@ -865,6 +868,7 @@ pub mod tests { let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); + dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); show_partitionning(&old_dag, &dag.instrs_partition); @@ -926,6 +930,7 @@ pub mod tests { // let input1 = dag.add_input(acc_precision, Shape::number()); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); + dag.detect_outputs(); let old_dag = dag; let dag = analyze_with_preferred(&old_dag, HIGH_PRECISION_PARTITION); show_partitionning(&old_dag, &dag.instrs_partition); @@ -983,6 +988,7 @@ pub mod tests { let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); let rounded1 = dag.add_expanded_round(input1, precision); let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); + dag.detect_outputs(); let old_dag = dag; let dag = analyze(&old_dag); // Partition 0 @@ -1036,6 +1042,7 @@ pub mod tests { _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, 1); let precisions: Vec<_> = (1..=max_precision).collect(); let p_cut = PartitionCut::from_precisions(&precisions); + dag.detect_outputs(); let dag = super::analyze( &dag, &CONFIG, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index 7e152866a..d15dc1dff 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -177,8 +177,11 @@ fn optimize_multi_independant_2_precisions() { let noise_factor = manp as f64; let mut dag_multi = v0_dag(sum_size, precision1, noise_factor); add_v0_dag(&mut dag_multi, sum_size, precision2, noise_factor); - let dag_1 = v0_dag(sum_size, precision1, noise_factor); - let dag_2 = v0_dag(sum_size, precision2, noise_factor); + dag_multi.detect_outputs(); + let mut dag_1 = v0_dag(sum_size, precision1, noise_factor); + dag_1.detect_outputs(); + let mut dag_2 = v0_dag(sum_size, precision2, noise_factor); + dag_2.detect_outputs(); if let Some(equiv) = equiv_2_single(&dag_multi, &dag_1, &dag_2) { assert!(equiv, "FAILED ON {precision1} {precision2} {manp}"); } else { @@ -205,6 +208,7 @@ fn dag_lut_sum_of_2_partitions_2_layer( if final_lut { _ = dag.add_lut(dot, FunctionTable::UNKWOWN, precision1); } + dag.detect_outputs(); dag } @@ -615,6 +619,7 @@ fn test_multi_rounded_fks_coherency() { let reduced_8 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 8); let reduced_4 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); _ = dag.add_dot([reduced_8, reduced_4], [1, 1]); + dag.detect_outputs(); let sol = optimize(&dag, &None, 0); assert!(sol.is_some()); let sol = sol.unwrap(); @@ -650,6 +655,7 @@ fn test_big_secret_key_sharing() { let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); let _ = dag.add_dot([lut1, lut2], [16, 1]); + dag.detect_outputs(); let config_sharing = Config { security_level: 128, maximum_acceptable_error_probability: _4_SIGMA, @@ -700,6 +706,7 @@ fn test_big_and_small_secret_key() { let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); let _ = dag.add_dot([lut1, lut2], [16, 1]); + dag.detect_outputs(); let config_sharing = Config { security_level: 128, maximum_acceptable_error_probability: _4_SIGMA, @@ -750,6 +757,7 @@ fn test_composition_2_partitions() { let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let input2 = dag.add_dot([input1, lut3], [1, 1]); let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); + dag.detect_outputs(); let normal_config = default_config(); let composed_config = Config { composable: true, @@ -780,6 +788,7 @@ fn test_composition_1_partition_not_composable() { let input1 = dag.add_dot([input1], [1 << 16]); let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); let _ = dag.add_dot([lut1], [1 << 16]); + dag.detect_outputs(); let normal_config = default_config(); let composed_config = Config { composable: true, @@ -808,6 +817,7 @@ fn test_maximal_multi() { let lut1 = dag.add_lut(input, FunctionTable::UNKWOWN, 8u8); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 8u8); _ = dag.add_dot([lut2], [1 << 16]); + dag.detect_outputs(); let sol = optimize(&dag, &None, 0).unwrap(); assert!(sol.macro_params.len() == 1); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs index 836bede20..20b93a304 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs @@ -6,7 +6,7 @@ use crate::dag::operator::{Operator, OperatorIndex, Precision}; use crate::dag::rewrite::round::expand_round_and_index_map; use crate::dag::unparametrized; use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; -use crate::optimization::dag::solo_key::analyze::{extra_final_values_to_check, out_variances}; +use crate::optimization::dag::solo_key::analyze::out_variances; use crate::optimization::dag::solo_key::symbolic_variance::SymbolicVariance; const ROUND_INNER_MULTI_PARAMETER: bool = false; @@ -143,12 +143,10 @@ impl PartitionCut { } } } - for (op_i, &need_decrypt) in extra_final_values_to_check(&dag).iter().enumerate() { - if need_decrypt { - for &origin in &noise_origins[op_i] { - max_output_norm2[origin] = max_output_norm2[origin].max(out_norm2(op_i)); - assert!(!max_output_norm2[origin].is_nan()); - } + for op_i in dag.get_output_index_iter() { + for &origin in &noise_origins[op_i] { + max_output_norm2[origin] = max_output_norm2[origin].max(out_norm2(op_i)); + assert!(!max_output_norm2[origin].is_nan()); } } let mut round_done: HashMap = HashMap::default(); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index 9b08773fc..b547211ff 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -69,8 +69,7 @@ fn extract_levelled_block(dag: &unparametrized::OperationDag, composable: bool) // all inputs and outputs in the same partition. let mut input_iter = dag.get_input_index_iter(); let first_inp = input_iter.next().unwrap(); - dag.get_output_index() - .into_iter() + dag.get_output_index_iter() .chain(input_iter) .for_each(|ind| uf.union(first_inp, ind)); } @@ -362,6 +361,7 @@ pub mod tests { let input = dag.add_input(10, Shape::number()); let lut1 = dag.add_lut(input, FunctionTable::UNKWOWN, 2); let output = dag.add_lut(lut1, FunctionTable::UNKWOWN, 10); + dag.detect_outputs(); let partitions = partitionning(&dag, false); assert!( partitions.instrs_partition[input.i].instruction_partition diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs index 99c8c7caf..a525d2fec 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/visualization.rs @@ -42,15 +42,7 @@ fn write_dot_svg(dot: &str, maybe_path: Option) -> PathBuf { } fn extract_node_inputs(node: &Operator) -> Vec { - match node { - Operator::Input { .. } => vec![], - Operator::LevelledOp { inputs, .. } | Operator::Dot { inputs, .. } => { - inputs.iter().map(|n| n.i).collect() - } - Operator::UnsafeCast { input, .. } - | Operator::Lut { input, .. } - | Operator::Round { input, .. } => vec![input.i], - } + node.get_inputs_iter().map(|id| id.i).collect() } fn extract_node_label(node: &Operator, index: usize) -> String { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 190347da1..0e4d32581 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -206,42 +206,17 @@ pub fn out_variances(dag: &unparametrized::OperationDag) -> Vec Vec { - let nb_ops = dag.operators.len(); - let mut extra_values_to_check = vec![true; nb_ops]; - for op in &dag.operators { - match op { - Op::Input { .. } => (), - Op::Lut { input, .. } | Op::UnsafeCast { input, .. } | Op::Round { input, .. } => { - extra_values_to_check[input.i] = false; - } - Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => { - for input in inputs { - extra_values_to_check[input.i] = false; - } - } - } - } - extra_values_to_check -} - fn extra_final_variances( dag: &unparametrized::OperationDag, out_variances: &[SymbolicVariance], ) -> Vec<(Precision, Shape, SymbolicVariance)> { - extra_final_values_to_check(dag) - .iter() - .enumerate() - .filter_map(|(i, &is_final)| { - if is_final { - Some(( - dag.out_precisions[i], - dag.out_shapes[i].clone(), - out_variances[i], - )) - } else { - None - } + dag.get_output_index_iter() + .map(|i| { + ( + dag.out_precisions[i], + dag.out_shapes[i].clone(), + out_variances[i], + ) }) .collect() } @@ -666,6 +641,7 @@ pub mod tests { fn test_1_input() { let mut graph = unparametrized::OperationDag::new(); let input1 = graph.add_input(1, Shape::number()); + graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -689,6 +665,7 @@ pub mod tests { let mut graph = unparametrized::OperationDag::new(); let input1 = graph.add_input(8, Shape::number()); let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 8); + graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -715,6 +692,7 @@ pub mod tests { let weights = Weights::vector([1, 2]); let norm2: f64 = 1.0 * 1.0 + 2.0 * 2.0; let dot = graph.add_dot([input1, input1], weights); + graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -746,6 +724,7 @@ pub mod tests { #[allow(clippy::imprecise_flops)] let manp = (1.0 * 1.0 + 2.0 * 2_f64).sqrt(); let dot = graph.add_levelled_op([input1, input1], cpx_dot, manp, Shape::number(), "dot"); + graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -775,6 +754,7 @@ pub mod tests { let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1); let dot2 = graph.add_dot([lut1, lut1], weights); let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN, 1); + graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -823,6 +803,7 @@ pub mod tests { let weights = &Weights::vector([2, 3]); let dot1 = graph.add_dot([input1, lut1], weights); let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1); + graph.detect_outputs(); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -849,6 +830,7 @@ pub mod tests { for i in 1..=max_precision { _ = graph.add_input(i, Shape::number()); } + graph.detect_outputs(); let analysis = analyze(&graph); assert!(analysis.constraints_by_precisions.len() == max_precision as usize); let mut prev_safe_noise_bound = 0.0; @@ -869,6 +851,7 @@ pub mod tests { let input = graph.add_input(p, Shape::number()); let _lut = graph.add_lut(input, FunctionTable::UNKWOWN, p); } + graph.detect_outputs(); let analysis = analyze(&graph); assert!(analysis.constraints_by_precisions.len() == max_precision as usize); let mut prev_safe_noise_bound = 0.0; @@ -896,6 +879,7 @@ pub mod tests { let weights = &Weights::number(2); _ = graph.add_dot([input1], weights); assert!(*graph.out_shapes.last().unwrap() == shape); + graph.detect_outputs(); let analysis = analyze(&graph); assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0); } @@ -912,6 +896,7 @@ pub mod tests { let weights = &Weights::vector([2, 3]); _ = graph.add_dot([input1, lut2], weights); assert!(*graph.out_shapes.last().unwrap() == shape); + graph.detect_outputs(); let analysis = analyze(&graph); assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0); assert_f64_eq(analysis.out_variances.last().unwrap().lut_coeff, 9.0); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 67e1322f4..8f92a2c28 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -596,6 +596,7 @@ pub(crate) mod tests { let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); let dot2 = dag.add_dot([lut1], [weight]); let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); + dag.detect_outputs(); } { let dag2 = analyze::analyze(