diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h index dd0beae31..0dcd2d876 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h @@ -20,6 +20,9 @@ namespace optimizer { std::unique_ptr createDagPass(optimizer::Config config, concrete_optimizer::Dag &dag); +void applyCompositionRules(optimizer::Config config, + concrete_optimizer::Dag &dag); + } // namespace optimizer } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 150a65359..322bdb8c6 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -973,6 +973,21 @@ std::unique_ptr createDagPass(optimizer::Config config, return std::make_unique(config, dag); } +// Adds the composition rules to the dag when the config is composable. +void applyCompositionRules(optimizer::Config config, + concrete_optimizer::Dag &dag) { + + if (config.composable) { + auto inputs = dag.get_input_indices(); + auto outputs = dag.get_output_indices(); + dag.add_compositions( + rust::Slice( + outputs.data(), outputs.size()), + rust::Slice( + inputs.data(), inputs.size())); + } +} + } // namespace optimizer } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 3a51aff91..caae41a35 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ 
b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -134,7 +134,10 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, if (pm.run(module.getOperation()).failed()) { return StreamStringError() << "Failed to create concrete-optimizer dag\n"; } + optimizer::applyCompositionRules(config, *dag); + std::optional description; + if (!constraint) { description = std::nullopt; } else { diff --git a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp index 7d883e2e2..05e7595ac 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp @@ -35,8 +35,7 @@ concrete_optimizer::Options options_from_config(optimizer::Config config) { /* .encoding = */ config.encoding, /* .cache_on_disk = */ config.cache_on_disk, /* .ciphertext_modulus_log = */ config.ciphertext_modulus_log, - /* .fft_precision = */ config.fft_precision, - /* .composable = */ config.composable}; + /* .fft_precision = */ config.fft_precision}; return options; } diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir index 2ff224951..3b3edcbe2 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir @@ -4,7 +4,7 @@ func.func @main(%arg0: tensor<5x!FHE.eint<5>>) -> !FHE.eint<5> { %weights = arith.constant dense<[-1, -1, -1, -1, -1]> : tensor<5xi6> %tlu = arith.constant dense<[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64> %0 = "FHELinalg.apply_lookup_table"(%arg0, %tlu) : (tensor<5x!FHE.eint<5>>, tensor<32xi64>) -> tensor<5x!FHE.eint<5>> - // CHECK: Dot { [[a:.*]], weights: ClearTensor 
{ shape: Shape { dimensions_size: [5] }, values: [-1, -1, -1, -1, -1] } } + // CHECK: Dot { [[a:.*]], weights: ClearTensor { shape: Shape { dimensions_size: [5] }, values: [-1, -1, -1, -1, -1] }, kind: Tensor } %1 = "FHELinalg.dot_eint_int"(%0, %weights) : (tensor<5x!FHE.eint<5>>, tensor<5xi6>) -> !FHE.eint<5> return %1 : !FHE.eint<5> } diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc index abf87c1b5..3357d0d15 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc @@ -418,15 +418,16 @@ TEST(CompileNotComposable, not_composable_2) { TestProgram circuit(options); auto err = circuit.compile(R"XXX( func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) { - %cst_1 = arith.constant 1 : i4 + %cst_1 = arith.constant 2 : i4 %cst_2 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64> - %1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> + %1 = "FHE.mul_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> %2 = "FHE.apply_lookup_table"(%1, %cst_2): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>) return %1, %2: !FHE.eint<3>, !FHE.eint<3> } )XXX"); ASSERT_OUTCOME_HAS_FAILURE_WITH_ERRORMSG( - err, "Program can not be composed: Output 1 has variance 1σ²In[0]."); + err, "Program can not be composed: Dag is not composable, because of " + "output 1: Partition 0 has input coefficient 4"); } TEST(CompileComposable, composable_supported_dag_mono) { diff --git a/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs b/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs index 3ac31bbbc..83cc4144b 100644 --- a/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs +++ b/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs @@ -47,7 
+47,6 @@ fn main() -> Result<(), Box> { ciphertext_modulus_log, fft_precision, complexity_model: &CpuComplexity::default(), - composable: false, }; let cache = decomposition::cache( diff --git a/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs b/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs index 240d44b86..068cf6efb 100644 --- a/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs +++ b/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs @@ -47,7 +47,6 @@ fn main() -> Result<(), Box> { ciphertext_modulus_log, fft_precision, complexity_model: &CpuComplexity::default(), - composable: false, }; let cache = decomposition::cache( diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index 0193bb302..c68574145 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -59,7 +59,6 @@ fn optimize_bootstrap(precision: u64, noise_factor: f64, options: ffi::Options) ciphertext_modulus_log: options.ciphertext_modulus_log, fft_precision: options.fft_precision, complexity_model: &CpuComplexity::default(), - composable: options.composable, }; let sum_size = 1; @@ -489,6 +488,20 @@ impl Dag { self.0.viz_string() } + fn get_input_indices(&self) -> Vec { + self.0 + .get_input_operators_iter() + .map(|n| ffi::OperatorIndex { index: n.id.0 }) + .collect() + } + + fn get_output_indices(&self) -> Vec { + self.0 + .get_output_operators_iter() + .map(|n| ffi::OperatorIndex { index: n.id.0 }) + .collect() + } + fn optimize(&self, options: ffi::Options) -> ffi::DagSolution { let processing_unit = processing_unit(options); let config = Config { @@ -498,13 +511,12 @@ impl Dag { ciphertext_modulus_log: options.ciphertext_modulus_log, fft_precision: options.fft_precision, complexity_model: 
&CpuComplexity::default(), - composable: options.composable, }; let search_space = SearchSpace::default(processing_unit); let encoding = options.encoding.into(); - if options.composable { + if self.0.is_composed() { let circuit_sol = concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize( &self.0, @@ -535,6 +547,32 @@ impl Dag { self.0.get_circuit_count() } + fn add_compositions(&mut self, froms: &[ffi::OperatorIndex], tos: &[ffi::OperatorIndex]) { + self.0.add_compositions( + froms + .iter() + .map(|a| OperatorIndex(a.index)) + .collect::>(), + tos.iter() + .map(|a| OperatorIndex(a.index)) + .collect::>(), + ); + } + + fn add_all_compositions(&mut self) { + let froms = self + .0 + .get_output_operators_iter() + .map(|o| o.id) + .collect::>(); + let tos = self + .0 + .get_input_operators_iter() + .map(|o| o.id) + .collect::>(); + self.0.add_compositions(froms, tos); + } + fn optimize_multi(&self, options: ffi::Options) -> ffi::CircuitSolution { let processing_unit = processing_unit(options); let config = Config { @@ -544,7 +582,6 @@ impl Dag { ciphertext_modulus_log: options.ciphertext_modulus_log, fft_precision: options.fft_precision, complexity_model: &CpuComplexity::default(), - composable: options.composable, }; let search_space = SearchSpace::default(processing_unit); @@ -763,6 +800,10 @@ mod ffi { fn optimize(self: &Dag, options: Options) -> DagSolution; + fn add_compositions(self: &mut Dag, froms: &[OperatorIndex], tos: &[OperatorIndex]); + + fn add_all_compositions(self: &mut Dag); + #[namespace = "concrete_optimizer::dag"] fn dump(self: &CircuitSolution) -> String; @@ -781,6 +822,10 @@ mod ffi { fn optimize_multi(self: &Dag, options: Options) -> CircuitSolution; + fn get_input_indices(self: &Dag) -> Vec; + + fn get_output_indices(self: &Dag) -> Vec; + fn NO_KEY_ID() -> u64; } @@ -857,7 +902,6 @@ mod ffi { pub cache_on_disk: bool, pub ciphertext_modulus_log: u32, pub fft_precision: u32, - pub composable: bool, } #[namespace = 
"concrete_optimizer::dag"] diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index f3d55819d..182bc1cd2 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -972,8 +972,12 @@ struct Dag final : public ::rust::Opaque { ::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept; ::rust::String dump() const noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; + void add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept; + void add_all_compositions() noexcept; ::std::size_t get_circuit_count() const noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; + ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_input_indices() const noexcept; + ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_output_indices() const noexcept; ~Dag() = delete; private: @@ -1111,7 +1115,6 @@ struct Options final { bool cache_on_disk; ::std::uint32_t ciphertext_modulus_log; ::std::uint32_t fft_precision; - bool composable; using IsRelocatable = ::std::true_type; }; @@ -1315,6 +1318,10 @@ void concrete_optimizer$cxxbridge1$DagBuilder$dump(::concrete_optimizer::DagBuil void concrete_optimizer$cxxbridge1$DagBuilder$tag_operator_as_output(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex op) noexcept; void concrete_optimizer$cxxbridge1$Dag$optimize(::concrete_optimizer::Dag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept; + +void 
concrete_optimizer$cxxbridge1$Dag$add_compositions(::concrete_optimizer::Dag &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept; + +void concrete_optimizer$cxxbridge1$Dag$add_all_compositions(::concrete_optimizer::Dag &self) noexcept; } // extern "C" namespace dag { @@ -1343,6 +1350,10 @@ extern "C" { void concrete_optimizer$cxxbridge1$Dag$optimize_multi(::concrete_optimizer::Dag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::CircuitSolution *return$) noexcept; +void concrete_optimizer$cxxbridge1$Dag$get_input_indices(::concrete_optimizer::Dag const &self, ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *return$) noexcept; + +void concrete_optimizer$cxxbridge1$Dag$get_output_indices(::concrete_optimizer::Dag const &self, ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *return$) noexcept; + ::std::uint64_t concrete_optimizer$cxxbridge1$NO_KEY_ID() noexcept; } // extern "C" @@ -1438,6 +1449,14 @@ void DagBuilder::tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex return ::std::move(return$.value); } +void Dag::add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept { + concrete_optimizer$cxxbridge1$Dag$add_compositions(*this, froms, tos); +} + +void Dag::add_all_compositions() noexcept { + concrete_optimizer$cxxbridge1$Dag$add_all_compositions(*this); +} + namespace dag { ::rust::String CircuitSolution::dump() const noexcept { ::rust::MaybeUninit<::rust::String> return$; @@ -1480,6 +1499,18 @@ namespace weights { return ::std::move(return$.value); } +::rust::Vec<::concrete_optimizer::dag::OperatorIndex> Dag::get_input_indices() const noexcept { + ::rust::MaybeUninit<::rust::Vec<::concrete_optimizer::dag::OperatorIndex>> return$; + concrete_optimizer$cxxbridge1$Dag$get_input_indices(*this, &return$.value); + 
return ::std::move(return$.value); +} + +::rust::Vec<::concrete_optimizer::dag::OperatorIndex> Dag::get_output_indices() const noexcept { + ::rust::MaybeUninit<::rust::Vec<::concrete_optimizer::dag::OperatorIndex>> return$; + concrete_optimizer$cxxbridge1$Dag$get_output_indices(*this, &return$.value); + return ::std::move(return$.value); +} + ::std::uint64_t NO_KEY_ID() noexcept { return concrete_optimizer$cxxbridge1$NO_KEY_ID(); } @@ -1498,6 +1529,15 @@ void cxxbridge1$box$concrete_optimizer$DagBuilder$drop(::rust::Box<::concrete_op void cxxbridge1$box$concrete_optimizer$Weights$dealloc(::concrete_optimizer::Weights *) noexcept; void cxxbridge1$box$concrete_optimizer$Weights$drop(::rust::Box<::concrete_optimizer::Weights> *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$new(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$drop(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$len(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$capacity(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept; +::concrete_optimizer::dag::OperatorIndex const *cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$data(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$reserve_total(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, ::std::size_t new_cap) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$set_len(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, ::std::size_t len) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$truncate(::rust::Vec<::concrete_optimizer::dag::OperatorIndex> *ptr, 
::std::size_t len) noexcept; + void cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$new(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> const *ptr) noexcept; void cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$drop(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> *ptr) noexcept; ::std::size_t cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$len(::rust::Vec<::concrete_optimizer::dag::SecretLweKey> const *ptr) noexcept; @@ -1601,6 +1641,38 @@ void Box<::concrete_optimizer::Weights>::drop() noexcept { cxxbridge1$box$concrete_optimizer$Weights$drop(this); } template <> +Vec<::concrete_optimizer::dag::OperatorIndex>::Vec() noexcept { + cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$new(this); +} +template <> +void Vec<::concrete_optimizer::dag::OperatorIndex>::drop() noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$drop(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::dag::OperatorIndex>::size() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$len(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::dag::OperatorIndex>::capacity() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$capacity(this); +} +template <> +::concrete_optimizer::dag::OperatorIndex const *Vec<::concrete_optimizer::dag::OperatorIndex>::data() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$data(this); +} +template <> +void Vec<::concrete_optimizer::dag::OperatorIndex>::reserve_total(::std::size_t new_cap) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$reserve_total(this, new_cap); +} +template <> +void Vec<::concrete_optimizer::dag::OperatorIndex>::set_len(::std::size_t len) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$set_len(this, len); +} +template <> +void Vec<::concrete_optimizer::dag::OperatorIndex>::truncate(::std::size_t len) { + return 
cxxbridge1$rust_vec$concrete_optimizer$dag$OperatorIndex$truncate(this, len); +} +template <> Vec<::concrete_optimizer::dag::SecretLweKey>::Vec() noexcept { cxxbridge1$rust_vec$concrete_optimizer$dag$SecretLweKey$new(this); } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index b254a8458..e893d708d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -953,8 +953,12 @@ struct Dag final : public ::rust::Opaque { ::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept; ::rust::String dump() const noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; + void add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept; + void add_all_compositions() noexcept; ::std::size_t get_circuit_count() const noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; + ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_input_indices() const noexcept; + ::rust::Vec<::concrete_optimizer::dag::OperatorIndex> get_output_indices() const noexcept; ~Dag() = delete; private: @@ -1092,7 +1096,6 @@ struct Options final { bool cache_on_disk; ::std::uint32_t ciphertext_modulus_log; ::std::uint32_t fft_precision; - bool composable; using IsRelocatable = ::std::true_type; }; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp index 5766b6ff0..c2ca86991 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp +++ 
b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp @@ -30,7 +30,6 @@ concrete_optimizer::Options default_options() { .cache_on_disk = true, .ciphertext_modulus_log = CIPHERTEXT_MODULUS_LOG, .fft_precision = 53, - .composable = false }; } @@ -38,7 +37,6 @@ concrete_optimizer::Options default_options() { TEST test_v0() { auto options = default_options(); - options.composable = true; concrete_optimizer::v0::Solution solution = concrete_optimizer::v0::optimize_bootstrap( PRECISION_1B, NOISE_DEVIATION_COEFF, options); @@ -261,7 +259,7 @@ TEST test_composable_dag_mono_fallback_on_dag_multi() { assert(!solution1.use_wop_pbs); assert(solution1.p_error < options.maximum_acceptable_error_probability); - options.composable = true; + dag->add_all_compositions(); auto solution2 = dag->optimize(options); assert(!solution2.use_wop_pbs); assert(solution2.p_error < options.maximum_acceptable_error_probability); @@ -298,7 +296,7 @@ TEST test_non_composable_dag_mono_fallback_on_woppbs() { assert(!solution1.use_wop_pbs); assert(solution1.p_error < options.maximum_acceptable_error_probability); - options.composable = true; + dag->add_all_compositions(); auto solution2 = dag->optimize(options); assert(solution2.p_error < options.maximum_acceptable_error_probability); assert(solution1.complexity < solution2.complexity); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs index 3a31ed3e3..519c7df4a 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs @@ -1,6 +1,6 @@ use super::{ClearTensor, Shape}; -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] pub enum DotKind { // inputs = [x,y,z], weights = [a,b,c], = x*a + y*b + z*c Simple, diff --git 
a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs index 5561d005a..16eaab9a4 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs @@ -4,6 +4,8 @@ use std::ops::Deref; use crate::dag::operator::tensor::{ClearTensor, Shape}; +use super::DotKind; + pub type Weights = ClearTensor; #[derive(Clone, PartialEq, Eq, Debug)] @@ -82,6 +84,7 @@ pub enum Operator { Dot { inputs: Vec, weights: Weights, + kind: DotKind, }, LevelledOp { inputs: Vec, @@ -116,7 +119,7 @@ impl Operator { } } -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] pub struct OperatorIndex(pub usize); impl Deref for OperatorIndex { @@ -143,7 +146,9 @@ impl fmt::Display for Operator { } => { write!(f, "Input : u{out_precision} x {out_shape:?}")?; } - Self::Dot { inputs, weights } => { + Self::Dot { + inputs, weights, .. 
+ } => { for (i, (input, weight)) in inputs.iter().zip(weights.values.iter()).enumerate() { if i > 0 { write!(f, " + ")?; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs index daa940f3c..efd3f26ef 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs @@ -41,6 +41,9 @@ pub(crate) fn regen( .for_each(|n| regen_dag.output_state[n.0].transition_use()); } } + // remap composition + regen_dag.composition = dag.composition.clone(); + regen_dag.composition.update_index(&old_index_to_new); (regen_dag, instructions_multi_map(&old_index_to_new)) } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index ad6a86414..cffabfff3 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -1,16 +1,20 @@ use crate::dag::operator::{ - dot_kind, DotKind, FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, - Shape, Weights, + FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights, }; -use std::{collections::HashSet, fmt}; +use std::{ + collections::{HashMap, HashSet}, + fmt, +}; + +use super::operator::dot_kind::{dot_kind, DotKind}; /// The name of the default. Used when adding operations directly on the dag instead of via a /// builder. const DEFAULT_CIRCUIT: &str = "_"; /// A state machine to define if an operator is used as output to a circuit. -#[derive(Debug, Clone, PartialEq, Copy)] -pub(crate) enum OutputState { +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum OutputState { /// The operator was created and neither used as input to another operator, nor tagged as output /// explicitly. 
It is considered an output. Unused, @@ -51,29 +55,29 @@ impl OutputState { /// A type referencing every informations related to an operator of the dag. #[derive(Debug, Clone)] #[allow(unused)] -pub(crate) struct DagOperator<'dag> { - pub(crate) id: OperatorIndex, - pub(crate) dag: &'dag Dag, - pub(crate) operator: &'dag Operator, - pub(crate) shape: &'dag Shape, - pub(crate) precision: &'dag Precision, - pub(crate) output_state: &'dag OutputState, - pub(crate) circuit_tag: &'dag String, +pub struct DagOperator<'dag> { + pub id: OperatorIndex, + pub dag: &'dag Dag, + pub operator: &'dag Operator, + pub shape: &'dag Shape, + pub precision: &'dag Precision, + pub output_state: &'dag OutputState, + pub circuit_tag: &'dag String, } impl<'dag> DagOperator<'dag> { /// Returns if the operator is an input. - pub(crate) fn is_input(&self) -> bool { + pub fn is_input(&self) -> bool { matches!(self.operator, Operator::Input { .. }) } /// Returns if the operator is an output. - pub(crate) fn is_output(&self) -> bool { + pub fn is_output(&self) -> bool { self.output_state.is_output() } /// Returns an iterator over the operators used as input to this operator. - pub(crate) fn get_inputs_iter(&self) -> impl Iterator> + '_ { + pub fn get_inputs_iter(&self) -> impl Iterator> { self.operator .get_inputs_iter() .map(|id| self.dag.get_operator(*id)) @@ -90,19 +94,19 @@ pub struct DagCircuit<'dag> { impl<'dag> DagCircuit<'dag> { /// Returns an iterator over the operators of this circuit. - pub(crate) fn get_operators_iter(&self) -> impl Iterator> + '_ { + pub fn get_operators_iter(&self) -> impl Iterator> + '_ { self.ids.iter().map(|id| self.dag.get_operator(*id)) } /// Returns an iterator over the circuit's input operators. 
#[allow(unused)] - pub(crate) fn get_input_operators_iter(&self) -> impl Iterator> + '_ { + pub fn get_input_operators_iter(&self) -> impl Iterator> + '_ { self.get_operators_iter().filter(DagOperator::is_input) } /// Returns an iterator over the circuit's output operators. #[allow(unused)] - pub(crate) fn get_output_operators_iter(&self) -> impl Iterator> + '_ { + pub fn get_output_operators_iter(&self) -> impl Iterator> + '_ { self.get_operators_iter().filter(DagOperator::is_output) } } @@ -176,7 +180,14 @@ impl<'dag> DagBuilder<'dag> { ) -> OperatorIndex { let inputs = inputs.into(); let weights = weights.into(); - self.add_operator(Operator::Dot { inputs, weights }) + // We detect the kind of dot to simplify matching later on. + let nb_inputs = inputs.len() as u64; + let input_shape = self.dag.get_operator(inputs[0]).shape; + self.add_operator(Operator::Dot { + inputs, + kind: dot_kind(nb_inputs, input_shape, &weights), + weights, + }) } pub fn add_levelled_op( @@ -353,28 +364,26 @@ impl<'dag> DagBuilder<'dag> { | Operator::UnsafeCast { input, .. } | Operator::Round { input, .. } => self.dag.out_shapes[input.0].clone(), Operator::Dot { - inputs, weights, .. + kind: DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor, + .. + } => Shape::number(), + Operator::Dot { + kind: DotKind::Broadcast { shape }, + .. + } => shape.clone(), + Operator::Dot { + kind: DotKind::Unsupported { .. }, + weights, + inputs, } => { - let input_shape = self.dag.out_shapes[inputs[0].0].clone(); - let kind = dot_kind(inputs.len() as u64, &input_shape, weights); - match kind { - DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => { - Shape::number() - } - DotKind::Broadcast { shape } => shape, - DotKind::Unsupported { .. 
} => { - let weights_shape = &weights.shape; - - println!(); - println!(); - println!("Error diagnostic on dot operation:"); - println!( - "Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>" - ); - println!(); - panic!("Unsupported or invalid dot operation") - } - } + let weights_shape = &weights.shape; + let input_shape = self.dag.get_operator(inputs[0]).shape; + println!(); + println!(); + println!("Error diagnostic on dot operation:"); + println!("Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>"); + println!(); + panic!("Unsupported or invalid dot operation") } } } @@ -392,6 +401,41 @@ impl<'dag> DagBuilder<'dag> { } } +#[derive(Clone, PartialEq, Debug, Default)] +pub(crate) struct CompositionRules(HashMap>); + +impl CompositionRules { + pub(crate) fn add(&mut self, from: OperatorIndex, to: OperatorIndex) { + let _ = self + .0 + .entry(to) + .and_modify(|e| e.push(from)) + .or_insert_with(|| [from].into()); + } + + pub(crate) fn update_index(&mut self, old_to_new_map: &[usize]) { + let mut old_map = HashMap::with_capacity(self.0.capacity()); + std::mem::swap(&mut old_map, &mut self.0); + for (old_id, mut compositions) in old_map { + let adapter = |old_id: OperatorIndex| -> OperatorIndex { + OperatorIndex(old_to_new_map[old_id.0]) + }; + compositions + .iter_mut() + .for_each(|from| *from = adapter(*from)); + let _ = self.0.insert(adapter(old_id), compositions); + } + } +} + +impl IntoIterator for CompositionRules { + type Item = (OperatorIndex, Vec); + type IntoIter = std::collections::hash_map::IntoIter>; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + /// A type containing a Directed Acyclic Graph of operators. /// /// This is the major datatype used to encode a module in the optimizer. 
It is equivalent to an fhe @@ -417,6 +461,8 @@ pub struct Dag { pub(crate) output_state: Vec, // Collect the circuit the operators are associated with pub(crate) circuit_tags: Vec, + // Composition rules + pub(crate) composition: CompositionRules, } impl fmt::Display for Dag { @@ -443,6 +489,7 @@ impl Dag { out_precisions: vec![], output_state: vec![], circuit_tags: vec![], + composition: CompositionRules::default(), } } @@ -586,13 +633,38 @@ impl Dag { ) } + /// Adds a composition rule to the dag. + pub fn add_composition(&mut self, from: OperatorIndex, to: OperatorIndex) { + debug_assert!(self.get_operator(from).is_output()); + debug_assert!(self.get_operator(to).is_input()); + self.composition.add(from, to); + } + + /// Adds a composition rule between every element of `from` and every element of `to`. + pub fn add_compositions, B: AsRef<[OperatorIndex]>>( + &mut self, + from: A, + to: B, + ) { + for from_i in from.as_ref() { + for to_i in to.as_ref() { + self.add_composition(*from_i, *to_i); + } + } + } + + /// Returns whether the dag contains a composition rule. + pub fn is_composed(&self) -> bool { + !self.composition.0.is_empty() + } + /// Returns an iterator over the operator indices. - pub(crate) fn get_indices_iter(&self) -> impl Iterator { + pub fn get_indices_iter(&self) -> impl Iterator { (0..self.len()).map(OperatorIndex) } /// Returns an iterator over the circuits contained in the dag. - pub(crate) fn get_circuits_iter(&self) -> impl Iterator> + '_ { + pub fn get_circuits_iter(&self) -> impl Iterator> + '_ { let mut circuits: HashSet = HashSet::new(); self.circuit_tags.iter().for_each(|name| { let _ = circuits.insert(name.to_owned()); @@ -607,7 +679,7 @@ impl Dag { /// # Note: /// /// Panics if no circuit with the given name exist in the dag. 
- pub(crate) fn get_circuit>(&self, circuit: A) -> DagCircuit { + pub fn get_circuit>(&self, circuit: A) -> DagCircuit { let circuit = circuit.as_ref().to_string(); assert!(self.circuit_tags.contains(&circuit)); let ids = self @@ -624,21 +696,22 @@ impl Dag { } /// Returns an iterator over the input operators of the dag. - pub(crate) fn get_input_operators_iter(&self) -> impl Iterator> { + #[allow(unused)] + pub fn get_input_operators_iter(&self) -> impl Iterator> { self.get_indices_iter() .map(|i| self.get_operator(i)) .filter(DagOperator::is_input) } /// Returns an iterator over the outputs operators of the dag. - pub(crate) fn get_output_operators_iter(&self) -> impl Iterator> { + pub fn get_output_operators_iter(&self) -> impl Iterator> { self.get_indices_iter() .map(|i| self.get_operator(i)) .filter(DagOperator::is_output) } /// Returns an iterator over the operators of the dag. - pub(crate) fn get_operators_iter(&self) -> impl Iterator> { + pub fn get_operators_iter(&self) -> impl Iterator> { self.get_indices_iter().map(|i| self.get_operator(i)) } @@ -647,7 +720,7 @@ impl Dag { /// # Note: /// /// Panics if the operator index is invalid. 
- pub(crate) fn get_operator(&self, id: OperatorIndex) -> DagOperator<'_> { + pub fn get_operator(&self, id: OperatorIndex) -> DagOperator<'_> { assert!(id.0 < self.len()); DagOperator { dag: self, @@ -686,7 +759,6 @@ impl Dag { #[cfg(test)] mod tests { use super::*; - use crate::dag::operator::Shape; #[test] fn output_marking() { @@ -722,7 +794,6 @@ mod tests { let mut graph = Dag::new(); let mut builder = graph.builder("_"); let input1 = builder.add_input(1, Shape::number()); - let input2 = builder.add_input(2, Shape::number()); let cpx_add = LevelledComplexity::ADDITION; @@ -778,6 +849,7 @@ mod tests { shape: Shape::vector(2), values: vec![1, 2] }, + kind: DotKind::Tensor }, Operator::Lut { input: dot, @@ -811,6 +883,7 @@ mod tests { Operator::Dot { inputs: vec![input1], weights: Weights::number(1 << 5), + kind: DotKind::Tensor, }, Operator::UnsafeCast { input: OperatorIndex(1), @@ -826,6 +899,7 @@ mod tests { Operator::Dot { inputs: vec![input1, OperatorIndex(3)], weights: Weights::vector([1, -1]), + kind: DotKind::Simple, }, Operator::UnsafeCast { input: OperatorIndex(4), @@ -836,6 +910,7 @@ mod tests { Operator::Dot { inputs: vec![OperatorIndex(5)], weights: Weights::number(1 << 4), + kind: DotKind::Tensor, }, Operator::UnsafeCast { input: OperatorIndex(6), @@ -851,6 +926,7 @@ mod tests { Operator::Dot { inputs: vec![OperatorIndex(5), OperatorIndex(8)], weights: Weights::vector([1, -1]), + kind: DotKind::Simple, }, Operator::UnsafeCast { input: OperatorIndex(9), @@ -861,6 +937,7 @@ mod tests { Operator::Dot { inputs: vec![OperatorIndex(10)], weights: Weights::number(1 << 3), + kind: DotKind::Tensor, }, Operator::UnsafeCast { input: OperatorIndex(11), @@ -876,6 +953,7 @@ mod tests { Operator::Dot { inputs: vec![OperatorIndex(10), OperatorIndex(13)], weights: Weights::vector([1, -1]), + kind: DotKind::Simple, }, Operator::UnsafeCast { input: OperatorIndex(14), diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs 
b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs index 62775295c..d23c255b7 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs @@ -18,7 +18,6 @@ pub struct Config<'a> { pub ciphertext_modulus_log: u32, pub fft_precision: u32, pub complexity_model: &'a dyn ComplexityModel, - pub composable: bool, } #[derive(Clone, Debug)] diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index 1f8bb65f4..9d9d5207b 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -1,8 +1,8 @@ -use crate::dag::operator::{ - dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, -}; +use std::ops::{Deref, Index, IndexMut}; + +use crate::dag::operator::{DotKind, LevelledComplexity, Operator, OperatorIndex, Precision}; use crate::dag::rewrite::round::expand_round_and_index_map; -use crate::dag::unparametrized; +use crate::dag::unparametrized::{Dag, DagOperator}; use crate::optimization::config::NoiseBoundConfig; use crate::optimization::dag::multi_parameters::partition_cut::PartitionCut; use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred; @@ -10,29 +10,310 @@ use crate::optimization::dag::multi_parameters::partitions::{ InstructionPartition, PartitionIndex, Transition, }; use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance; -use crate::optimization::dag::solo_key::analyze::{first, safe_noise_bound}; -use crate::optimization::Err::NotComposable; -use crate::optimization::Result; +use 
crate::optimization::dag::solo_key::analyze::safe_noise_bound; +use crate::optimization::{Err, Result}; use super::complexity::OperationsCount; use super::keys_spec; use super::operations_value::OperationsValue; +use super::partitions::Partitions; use super::variance_constraint::VarianceConstraint; - use crate::utils::square; -// private short convention -use DotKind as DK; +const MAX_FORWARDING: u16 = 1000; -type Op = Operator; +#[derive(Debug, Clone)] +pub struct PartitionedDag { + pub(crate) dag: Dag, + pub(crate) partitions: Partitions, +} + +impl PartitionedDag { + fn get_initial_variances(&self) -> Variances { + let vars = self + .dag + .get_operators_iter() + .map(|op| { + if op.is_input() { + let mut var = OperatorVariance::nan(self.partitions.nb_partitions); + let partition = self.partitions[op.id].instruction_partition; + var[partition] = + SymbolicVariance::input(self.partitions.nb_partitions, partition); + var + } else { + OperatorVariance::nan(self.partitions.nb_partitions) + } + }) + .collect(); + Variances { vars } + } +} + +#[derive(Debug)] +pub struct VariancedDagOperator<'a> { + dag: &'a VariancedDag, + id: OperatorIndex, +} + +impl<'a> VariancedDagOperator<'a> { + #[allow(unused)] + fn get_inputs_iter(&self) -> impl Iterator> { + self.operator() + .get_inputs_iter() + .collect::>() + .into_iter() + .map(|n| self.dag.get_operator(n.id)) + } + + pub(crate) fn operator(&self) -> DagOperator<'a> { + self.dag.dag.get_operator(self.id) + } + + #[allow(unused)] + pub(crate) fn partition(&self) -> &InstructionPartition { + &self.dag.partitions[self.id] + } + + pub(crate) fn variance(&self) -> &OperatorVariance { + &self.dag.variances[self.id] + } +} + +pub struct VariancedDagOperatorMut<'a> { + dag: &'a mut VariancedDag, + id: OperatorIndex, +} + +impl<'a> VariancedDagOperatorMut<'a> { + fn get_inputs_iter(&self) -> impl Iterator> { + self.operator() + .get_inputs_iter() + .collect::>() + .into_iter() + .map(|n| self.dag.get_operator(n.id)) + } + + 
pub(crate) fn operator(&self) -> DagOperator<'_> { + self.dag.dag.get_operator(self.id) + } + + pub(crate) fn partition(&self) -> &InstructionPartition { + &self.dag.partitions[self.id] + } + + #[allow(unused)] + pub(crate) fn partition_mut(&mut self) -> &mut InstructionPartition { + &mut self.dag.partitions[self.id] + } + + #[allow(unused)] + pub(crate) fn variance(&self) -> &OperatorVariance { + &self.dag.variances[self.id] + } + + pub(crate) fn variance_mut(&mut self) -> &mut OperatorVariance { + &mut self.dag.variances[self.id] + } +} + +#[derive(Debug, Clone)] +pub struct VariancedDag { + pub(crate) dag: Dag, + pub(crate) partitions: Partitions, + pub(crate) variances: Variances, +} + +impl VariancedDag { + fn try_from_partitioned(partitioned: PartitionedDag) -> Result { + // We compute the initial variances with noise at input nodes and NANs everywhere + // else. + let variances = partitioned.get_initial_variances(); + let PartitionedDag { dag, partitions } = partitioned; + let mut varianced = Self { + dag, + partitions, + variances, + }; + + // We forward the noise once to verify the composability. + let _ = varianced.forward_noise(); + varianced.check_composability()?; + varianced.apply_composition_rules(); + + // We loop, forwarding the noise, until it settles. + for _ in 0..MAX_FORWARDING { + // The noise gets computed from inputs down to outputs. + if varianced.forward_noise() { + // Noise settled, we return the varianced dag. 
+ return Ok(varianced); + } + // The noise of the inputs gets updated following the composition rules + varianced.apply_composition_rules(); + } + + panic!("Forwarding of noise did not reach a fixed point.") + } + + fn get_operator(&self, index: OperatorIndex) -> VariancedDagOperator<'_> { + VariancedDagOperator { + dag: self, + id: index, + } + } + + fn get_operator_mut(&mut self, index: OperatorIndex) -> VariancedDagOperatorMut<'_> { + VariancedDagOperatorMut { + dag: self, + id: index, + } + } + + /// Patches the inputs following the composition rules. + fn apply_composition_rules(&mut self) { + for (to, froms) in self.dag.composition.clone() { + let maxed_variance = froms + .into_iter() + .map(|id| self.get_operator(id).variance().to_owned()) + .reduce(|acc, var| acc.partition_wise_max(&var)) + .unwrap(); + let mut input = self.get_operator_mut(to); + *(input.variance_mut()) = maxed_variance; + } + } + + /// Propagates the noise downward in the graph. + fn forward_noise(&mut self) -> bool { + // We save the old variance to compute the diff at the end. + let old_variances = self.variances.clone(); + let nb_partitions = self.partitions.nb_partitions; + + // We loop through the operators and propagate the noise. + for operator_id in self.dag.get_indices_iter() { + let mut operator = self.get_operator_mut(operator_id); + // Inputs are already computed + if operator.operator().is_input() { + continue; + } + let max_var = |acc: SymbolicVariance, input: SymbolicVariance| acc.max(&input); + // Operator variance will be used to override the noise + let mut operator_variance = OperatorVariance::nan(nb_partitions); + // We first compute the noise in the partition of the operator + operator_variance[operator.partition().instruction_partition] = match operator + .operator() + .operator + { + Operator::Input { .. } => unreachable!(), + Operator::Lut { .. 
} => SymbolicVariance::after_pbs( + nb_partitions, + operator.partition().instruction_partition, + ), + Operator::LevelledOp { manp, .. } => { + let max_var = operator + .get_inputs_iter() + .map(|a| a.variance()[operator.partition().instruction_partition].clone()) + .reduce(max_var) + .unwrap(); + max_var.after_levelled_op(*manp) + } + Operator::Dot { + kind: DotKind::CompatibleTensor { .. }, + .. + } => todo!("TODO"), + Operator::Dot { + kind: DotKind::Unsupported { .. }, + .. + } => panic!("Unsupported"), + Operator::Dot { + inputs, + weights, + kind: DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. }, + } if inputs.len() == 1 => { + let var = operator + .get_inputs_iter() + .next() + .unwrap() + .variance() + .clone(); + weights + .values + .iter() + .fold(SymbolicVariance::ZERO, |acc, weight| { + acc + var[operator.partition().instruction_partition].clone() + * square(*weight as f64) + }) + } + Operator::Dot { + weights, + kind: DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. }, + .. + } => weights + .values + .iter() + .zip(operator.get_inputs_iter().map(|n| n.variance().clone())) + .fold(SymbolicVariance::ZERO, |acc, (weight, var)| { + acc + var[operator.partition().instruction_partition].clone() + * square(*weight as f64) + }), + Operator::UnsafeCast { .. } => { + operator.get_inputs_iter().next().unwrap().variance() + [operator.partition().instruction_partition] + .clone() + } + Operator::Round { .. 
} => { + unreachable!("Round should have been either expanded or integrated to a lut") + } + }; + // We add the noise for the transitions to alternative representations + operator + .partition() + .alternative_output_representation + .iter() + .for_each(|index| { + operator_variance[*index] = operator_variance + [operator.partition().instruction_partition] + .after_partition_keyswitch_to_big( + operator.partition().instruction_partition, + *index, + ); + }); + // We override the noise + *operator.variance_mut() = operator_variance; + } + + // We return whether there is a diff or not. + old_variances == self.variances + } + + #[allow(unused)] + fn check_composability(&self) -> Result<()> { + self.dag + .composition + .clone() + .into_iter() + .flat_map(|(to, froms)| froms.into_iter()) + .map(|i| self.get_operator(i)) + .filter(|op| op.operator().is_output()) + .try_for_each(|op| { + let id = op.id; + op.variance() + .check_growing_input_noise() + .map_err(|err| match err { + Err::NotComposable(prev) => Err::NotComposable(format!( + "Dag is not composable, because of output {id}: {prev}" + )), + Err::NoParametersFound => Err::NoParametersFound, + }) + }) + } +} #[derive(Debug)] pub struct AnalyzedDag { - pub operators: Vec, + pub operators: Vec, // Collect all operators ouput variances pub nb_partitions: usize, pub instrs_partition: Vec, - pub out_variances: Vec>, + pub instrs_variances: Vec, // The full dag levelled complexity pub levelled_complexity: LevelledComplexity, // All variance constraints including dominated ones @@ -46,11 +327,10 @@ pub struct AnalyzedDag { } pub fn analyze( - dag: &unparametrized::Dag, + dag: &Dag, noise_config: &NoiseBoundConfig, p_cut: &Option, default_partition: PartitionIndex, - composable: bool, ) -> Result { let (dag, instruction_rewrite_index) = expand_round_and_index_map(dag); let levelled_complexity = LevelledComplexity::ZERO; @@ -61,45 +341,20 @@ pub fn analyze( Some(p_cut) => p_cut.clone(), None => 
PartitionCut::for_each_precision(&dag), }; - let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition, composable); - let instrs_partition = partitions.instrs_partition; - let nb_partitions = partitions.nb_partitions; - let mut out_variances = self::out_variances(&dag, nb_partitions, &instrs_partition, &None); - if composable { - // Verify that there is no input symbol in the symbolic variances of the outputs. - check_composability(&dag, &out_variances, nb_partitions)?; - // Get the largest output out_variance - let largest_output_variances = dag - .get_output_operators_iter() - .map(|op| out_variances[op.id.0].clone()) - .reduce(|lhs, rhs| { - lhs.into_iter() - .zip(rhs) - .map(|(lhsi, rhsi)| lhsi.max(&rhsi)) - .collect() - }) - .expect("Failed to get the largest output variance."); - // Re-compute the out variances with input variances overriden by input variances - out_variances = self::out_variances( - &dag, - nb_partitions, - &instrs_partition, - &Some(largest_output_variances), - ); - } - let variance_constraints = - collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances); + let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition); + let partitioned_dag = PartitionedDag { dag, partitions }; + let varianced_dag = VariancedDag::try_from_partitioned(partitioned_dag)?; + let variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config); let undominated_variance_constraints = VarianceConstraint::remove_dominated(&variance_constraints); - let operations_count_per_instrs = - collect_operations_count(&dag, nb_partitions, &instrs_partition); + let operations_count_per_instrs = collect_operations_count(&varianced_dag); let operations_count = sum_operations_count(&operations_count_per_instrs); Ok(AnalyzedDag { - operators: dag.operators, + operators: varianced_dag.dag.operators, instruction_rewrite_index, - nb_partitions, - instrs_partition, - out_variances, + 
nb_partitions: varianced_dag.partitions.nb_partitions, + instrs_partition: varianced_dag.partitions.instrs_partition, + instrs_variances: varianced_dag.variances.vars, levelled_complexity, variance_constraints, undominated_variance_constraints, @@ -109,41 +364,6 @@ pub fn analyze( }) } -fn check_composability( - dag: &unparametrized::Dag, - symbolic_variances: &[Vec], - nb_partitions: usize, -) -> Result<()> { - // If the circuit only contains inputs, then it is composable. - let only_inputs = dag - .operators - .iter() - .all(|node| matches!(node, Operator::Input { .. })); - if only_inputs { - return Ok(()); - } - - // If the circuit outputs are free from input variances, it means that every outputs are - // refreshed, and the function can be composed. - let in_var_in_out_var = dag - .get_output_operators_iter() - .flat_map(|op| symbolic_variances[op.id.0].iter().map(move |v| (op.id, v))) - .find_map(|(output_index, sym_var)| { - (0..nb_partitions) - .find(|partition| { - let coeff = sym_var.coeff_input(*partition); - coeff != 0.0f64 && !coeff.is_nan() - }) - .map(|_| (output_index, sym_var)) - }); - match in_var_in_out_var { - Some((output_id, sym_var)) => Err(NotComposable(format!( - "Output {output_id} has variance {sym_var}." - ))), - None => Ok(()), - } -} - pub fn original_instrs_partition( dag: &AnalyzedDag, keys: &keys_spec::ExpandedCircuitKeys, @@ -165,7 +385,7 @@ pub fn original_instrs_partition( for (i, new_instruction) in new_instructions.iter().enumerate() { // focus on TLU information let new_instr_part = &dag.instrs_partition[new_instruction.0]; - if let Op::Lut { .. } = dag.operators[new_instruction.0] { + if let Operator::Lut { .. 
} = dag.operators[new_instruction.0] { let ks_dst = new_instr_part.instruction_partition; partition = Some(ks_dst); #[allow(clippy::match_on_vec_items)] @@ -175,8 +395,8 @@ pub fn original_instrs_partition( _ => unreachable!(), }; input_partition = Some(ks_src); - let ks_key = ks_keys[ks_src][ks_dst].as_ref().unwrap().identifier; - let pbs_key = pbs_keys[ks_dst].identifier; + let ks_key = ks_keys[ks_src.0][ks_dst.0].as_ref().unwrap().identifier; + let pbs_key = pbs_keys[ks_dst.0].identifier; assert!(tlu_keyswitch_key.unwrap_or(ks_key) == ks_key); assert!(tlu_bootstrap_key.unwrap_or(pbs_key) == pbs_key); tlu_keyswitch_key = Some(ks_key); @@ -190,7 +410,7 @@ pub fn original_instrs_partition( .iter() .next() .unwrap(); - let key = fks_keys[src][dst].as_ref().unwrap().identifier; + let key = fks_keys[src.0][dst.0].as_ref().unwrap().identifier; assert!(conversion_key.unwrap_or(key) == key); conversion_key = Some(key); } @@ -204,10 +424,10 @@ pub fn original_instrs_partition( partition.unwrap_or(dag.instrs_partition[new_instructions[0].0].instruction_partition); let input_partition = input_partition.unwrap_or(partition); let merged = keys_spec::InstructionKeys { - input_key: big_keys[input_partition].identifier, + input_key: big_keys[input_partition.0].identifier, tlu_keyswitch_key: tlu_keyswitch_key.unwrap_or(unknown), tlu_bootstrap_key: tlu_bootstrap_key.unwrap_or(unknown), - output_key: big_keys[partition].identifier, + output_key: big_keys[partition.0].identifier, extra_conversion_keys: conversion_key.iter().copied().collect(), tlu_circuit_bootstrap_key: keys_spec::NO_KEY_ID, tlu_private_functional_packing_key: keys_spec::NO_KEY_ID, @@ -217,105 +437,100 @@ pub fn original_instrs_partition( result } -fn out_variance( - op: &Operator, - out_shapes: &[Shape], - out_variances: &[Vec], - nb_partitions: usize, - instr_partition: &InstructionPartition, - input_override: Option>, -) -> Vec { - // If an override is given for input and we have an input node, we override. 
- if let (Some(overr), Op::Input { .. }) = (input_override, op) { - return overr; +#[derive(PartialEq, Debug, Clone)] +pub struct OperatorVariance { + pub(crate) vars: Vec, +} + +impl Index for OperatorVariance { + type Output = SymbolicVariance; + + fn index(&self, index: PartitionIndex) -> &Self::Output { + &self.vars[index.0] } - // one variance per partition, in case the result is converted - let partition = instr_partition.instruction_partition; - let out_variance_of = |input: &OperatorIndex| { - assert!(input.0 < out_variances.len()); - assert!(partition < out_variances[input.0].len()); - assert!(out_variances[input.0][partition] != SymbolicVariance::ZERO); - assert!(!out_variances[input.0][partition].coeffs.values[0].is_nan()); - assert!(out_variances[input.0][partition].partition != usize::MAX); - out_variances[input.0][partition].clone() - }; - let max_variance = |acc: SymbolicVariance, input: SymbolicVariance| acc.max(&input); - let variance = match op { - Op::Input { .. } => SymbolicVariance::input(nb_partitions, partition), - Op::Lut { .. } => SymbolicVariance::after_pbs(nb_partitions, partition), - Op::LevelledOp { inputs, manp, .. } => { - let inputs_variance = inputs.iter().map(out_variance_of); - let max_variance = inputs_variance.reduce(max_variance).unwrap(); - max_variance.after_levelled_op(*manp) +} + +impl IndexMut for OperatorVariance { + fn index_mut(&mut self, index: PartitionIndex) -> &mut Self::Output { + &mut self.vars[index.0] + } +} + +impl Deref for OperatorVariance { + type Target = [SymbolicVariance]; + + fn deref(&self) -> &Self::Target { + &self.vars + } +} + +impl OperatorVariance { + pub fn nan(nb_partitions: usize) -> Self { + Self { + vars: (0..nb_partitions) + .map(|_| SymbolicVariance::nan(nb_partitions)) + .collect(), } - Op::Dot { - inputs, weights, .. 
- } => { - let input_shape = first(inputs, out_shapes); - let kind = dot_kind(inputs.len() as u64, input_shape, weights); - match kind { - DK::Simple | DK::Tensor | DK::Broadcast { .. } => { - let inputs_variance = (0..weights.values.len()).map(|j| { - let input = if inputs.len() > 1 { - inputs[j] - } else { - inputs[0] - }; - out_variance_of(&input) - }); - let mut out_variance = SymbolicVariance::ZERO; - for (input_variance, &weight) in inputs_variance.zip(&weights.values) { - assert!(input_variance != SymbolicVariance::ZERO); - out_variance += input_variance * square(weight as f64); - } - out_variance + } + + pub fn partition_wise_max(&self, other: &Self) -> Self { + let vars = self + .vars + .iter() + .zip(other.vars.iter()) + .map(|(s, o)| s.max(o)) + .collect(); + Self { vars } + } + + pub fn check_growing_input_noise(&self) -> Result<()> { + self.vars + .iter() + .flat_map(|var| { + PartitionIndex::range(0, var.nb_partitions()).map(|i| (i, var.coeff_input(i))) + }) + .try_for_each(|(partition, coeff)| { + if !coeff.is_nan() && coeff > 1.0 { + Result::Err(Err::NotComposable(format!( + "Partition {partition} has input coefficient {coeff}" + ))) + } else { + Ok(()) } - DK::CompatibleTensor { .. } => todo!("TODO"), - DK::Unsupported { .. } => panic!("Unsupported"), - } - } - Op::UnsafeCast { input, .. } => out_variance_of(input), - Op::Round { .. 
} => { - unreachable!("Round should have been either expanded or integrated to a lut") - } - }; - // Injecting NAN in unused symbolic variance to detect bad use - let unused = SymbolicVariance::nan(nb_partitions); - let mut result = vec![unused; nb_partitions]; - for &dst_partition in &instr_partition.alternative_output_representation { - let src_partition = partition; - // make converted variance available in dst_partition - result[dst_partition] = - variance.after_partition_keyswitch_to_big(src_partition, dst_partition); + }) } - result[partition] = variance; - result } -fn out_variances( - dag: &unparametrized::Dag, - nb_partitions: usize, - instrs_partition: &[InstructionPartition], - input_override: &Option>, -) -> Vec> { - let nb_ops = dag.operators.len(); - let mut out_variances = Vec::with_capacity(nb_ops); - for (op, instr_partition) in dag.operators.iter().zip(instrs_partition) { - let vf = out_variance( - op, - &dag.out_shapes, - &out_variances, - nb_partitions, - instr_partition, - input_override.clone(), - ); - out_variances.push(vf); - } - out_variances +#[derive(PartialEq, Debug, Clone)] +pub struct Variances { + pub(crate) vars: Vec, } +impl Index for Variances { + type Output = OperatorVariance; + + fn index(&self, index: OperatorIndex) -> &Self::Output { + &self.vars[index.0] + } +} + +impl IndexMut for Variances { + fn index_mut(&mut self, index: OperatorIndex) -> &mut Self::Output { + &mut self.vars[index.0] + } +} + +impl Deref for Variances { + type Target = [OperatorVariance]; + + fn deref(&self) -> &Self::Target { + &self.vars + } +} + +#[allow(unused)] fn variance_constraint( - dag: &unparametrized::Dag, + dag: &Dag, noise_config: &NoiseBoundConfig, partition: PartitionIndex, op_i: usize, @@ -333,21 +548,25 @@ fn variance_constraint( } } +#[allow(unused)] #[allow(clippy::float_cmp)] #[allow(clippy::match_on_vec_items)] fn collect_all_variance_constraints( - dag: &unparametrized::Dag, + dag: &VariancedDag, noise_config: &NoiseBoundConfig, - 
instrs_partition: &[InstructionPartition], - out_variances: &[Vec], ) -> Vec { + let VariancedDag { + dag, + partitions, + variances, + } = dag; let mut constraints = vec![]; for op in dag.get_operators_iter() { - let partition = instrs_partition[op.id.0].instruction_partition; - if let Op::Lut { input, .. } = op.operator { + let partition = partitions[op.id].instruction_partition; + if let Operator::Lut { input, .. } = op.operator { let precision = dag.out_precisions[input.0]; let dst_partition = partition; - let src_partition = match instrs_partition[op.id.0].inputs_transition[0] { + let src_partition = match partitions[op.id].inputs_transition[0] { None => dst_partition, Some(Transition::Internal { src_partition }) => { assert!(src_partition != dst_partition); @@ -355,7 +574,7 @@ fn collect_all_variance_constraints( } Some(Transition::Additional { src_partition }) => { assert!(src_partition != dst_partition); - let variance = &out_variances[input.0][dst_partition]; + let variance = &variances[*input][dst_partition]; assert!( variance.coeff_partition_keyswitch_to_big(src_partition, dst_partition) == 1.0 @@ -363,7 +582,7 @@ fn collect_all_variance_constraints( dst_partition } }; - let variance = &out_variances[input.0][src_partition].clone(); + let variance = &variances[*input][src_partition].clone(); let variance = variance .after_partition_keyswitch_to_small(src_partition, dst_partition) .after_modulus_switching(partition); @@ -378,7 +597,7 @@ fn collect_all_variance_constraints( } if op.is_output() { let precision = dag.out_precisions[op.id.0]; - let variance = out_variances[op.id.0][partition].clone(); + let variance = variances[op.id][partition].clone(); constraints.push(variance_constraint( dag, noise_config, @@ -392,15 +611,16 @@ fn collect_all_variance_constraints( constraints } +#[allow(unused)] #[allow(clippy::match_on_vec_items)] fn operations_counts( - dag: &unparametrized::Dag, + dag: &Dag, op: &Operator, nb_partitions: usize, instr_partition: 
&InstructionPartition, ) -> OperationsCount { let mut counts = OperationsValue::zero(nb_partitions); - if let Op::Lut { input, .. } = op { + if let Operator::Lut { input, .. } = op { let partition = instr_partition.instruction_partition; let nb_lut = dag.out_shapes[input.0].flat_size() as f64; let src_partition = match instr_partition.inputs_transition[0] { @@ -416,18 +636,24 @@ fn operations_counts( OperationsCount { counts } } -fn collect_operations_count( - dag: &unparametrized::Dag, - nb_partitions: usize, - instrs_partition: &[InstructionPartition], -) -> Vec { - dag.operators +#[allow(unused)] +fn collect_operations_count(dag: &VariancedDag) -> Vec { + dag.dag + .operators .iter() .enumerate() - .map(|(i, op)| operations_counts(dag, op, nb_partitions, &instrs_partition[i])) + .map(|(i, op)| { + operations_counts( + &dag.dag, + op, + dag.partitions.nb_partitions, + &dag.partitions[OperatorIndex(i)], + ) + }) .collect() } +#[allow(unused)] fn sum_operations_count(all_counts: &[OperationsCount]) -> OperationsCount { let mut sum_counts = OperationsValue::zero(all_counts[0].counts.nb_partitions()); for OperationsCount { counts } in all_counts { @@ -455,13 +681,18 @@ pub mod tests { default_partition: PartitionIndex, ) -> AnalyzedDag { let p_cut = PartitionCut::for_each_precision(dag); - super::analyze(dag, &CONFIG, &Some(p_cut), default_partition, false).unwrap() + super::analyze(dag, &CONFIG, &Some(p_cut), default_partition).unwrap() } #[allow(clippy::float_cmp)] - fn assert_input_on(dag: &AnalyzedDag, partition: usize, op_i: usize, expected_coeff: f64) { + fn assert_input_on( + dag: &AnalyzedDag, + partition: PartitionIndex, + op_i: usize, + expected_coeff: f64, + ) { for symbolic_variance_partition in [LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION] { - let sb = dag.out_variances[op_i][partition].clone(); + let sb = dag.instrs_variances[op_i][partition].clone(); let coeff = if sb == SymbolicVariance::ZERO { 0.0 } else { @@ -471,7 +702,7 @@ pub mod tests { 
assert!( coeff == expected_coeff, "INCORRECT INPUT COEFF ON GOOD PARTITION {:?} {:?} {} {}", - dag.out_variances[op_i], + dag.instrs_variances[op_i], partition, coeff, expected_coeff @@ -480,7 +711,7 @@ pub mod tests { assert!( coeff == 0.0, "INCORRECT INPUT COEFF ON WRONG PARTITION {:?} {:?} {} {}", - dag.out_variances[op_i], + dag.instrs_variances[op_i], partition, coeff, expected_coeff @@ -490,11 +721,16 @@ pub mod tests { } #[allow(clippy::float_cmp)] - fn assert_pbs_on(dag: &AnalyzedDag, partition: usize, op_i: usize, expected_coeff: f64) { + fn assert_pbs_on( + dag: &AnalyzedDag, + partition: PartitionIndex, + op_i: usize, + expected_coeff: f64, + ) { for symbolic_variance_partition in [LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION] { - let sb = dag.out_variances[op_i][partition].clone(); - eprintln!("{:?}", dag.out_variances[op_i]); - eprintln!("{:?}", dag.out_variances[op_i][partition]); + let sb = dag.instrs_variances[op_i][partition].clone(); + eprintln!("{:?}", dag.instrs_variances[op_i]); + eprintln!("{:?}", dag.instrs_variances[op_i][partition]); let coeff = if sb == SymbolicVariance::ZERO { 0.0 } else { @@ -504,7 +740,7 @@ pub mod tests { assert!( coeff == expected_coeff, "INCORRECT PBS COEFF ON GOOD PARTITION {:?} {:?} {} {}", - dag.out_variances[op_i], + dag.instrs_variances[op_i], partition, coeff, expected_coeff @@ -513,7 +749,7 @@ pub mod tests { assert!( coeff == 0.0, "INCORRECT PBS COEFF ON GOOD PARTITION {:?} {:?} {} {}", - dag.out_variances[op_i], + dag.instrs_variances[op_i], partition, coeff, expected_coeff @@ -523,22 +759,53 @@ pub mod tests { } #[test] - fn test_composition_with_inputs_only() { + fn test_composition_with_nongrowing_inputs_only() { let mut dag = unparametrized::Dag::new(); - let _ = dag.add_input(1, Shape::number()); + let inp = dag.add_input(1, Shape::number()); + let oup = dag.add_levelled_op( + [inp], + LevelledComplexity::ZERO, + 1.0, + Shape::number(), + "comment", + ); + dag.add_composition(oup, inp); let p_cut = 
PartitionCut::for_each_precision(&dag); - let res = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true); - assert!(res.is_ok()); + let analyzed_dag = + super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION).unwrap(); + let last_var = analyzed_dag.instrs_variances[analyzed_dag.instrs_variances.len() - 1] + [PartitionIndex(0)] + .to_string(); + assert_eq!(last_var, "1σ²In[0]"); + } + + #[test] + #[should_panic( + expected = "called `Result::unwrap()` on an `Err` value: NotComposable(\"Dag is not composable, because of output 1: Partition 0 has input coefficient 1.2100000000000002\")" + )] + fn test_composition_with_growing_inputs_panics() { + let mut dag = unparametrized::Dag::new(); + let inp = dag.add_input(1, Shape::number()); + let oup = dag.add_levelled_op( + [inp], + LevelledComplexity::ZERO, + 1.1, + Shape::number(), + "comment", + ); + dag.add_composition(oup, inp); + let p_cut = PartitionCut::for_each_precision(&dag); + let _ = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION).unwrap(); } #[test] fn test_composition_1_partition() { let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(1, Shape::number()); - let _ = dag.add_lut(input1, FunctionTable::UNKWOWN, 2); + let output = dag.add_lut(input1, FunctionTable::UNKWOWN, 2); + dag.add_composition(output, input1); let p_cut = PartitionCut::for_each_precision(&dag); - let dag = - super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true).unwrap(); + let dag = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION).unwrap(); assert!(dag.nb_partitions == 1); let actual_constraint_strings = dag .variance_constraints @@ -559,9 +826,9 @@ pub mod tests { let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 6); let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let input2 = dag.add_dot([input1, lut3], [1, 1]); - let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); - let analyzed_dag = - super::analyze(&dag, 
&CONFIG, &None, LOW_PRECISION_PARTITION, true).unwrap(); + let output = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); + dag.add_compositions([output], [input1]); + let analyzed_dag = super::analyze(&dag, &CONFIG, &None, LOW_PRECISION_PARTITION).unwrap(); assert_eq!(analyzed_dag.nb_partitions, 2); let actual_constraint_strings = analyzed_dag .variance_constraints @@ -601,9 +868,10 @@ pub mod tests { let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let a = dag.add_dot([input2, lut3], [1, 1]); let b = dag.add_dot([input1, lut3], [1, 1]); - let _ = dag.add_lut(a, FunctionTable::UNKWOWN, 3); - let _ = dag.add_lut(b, FunctionTable::UNKWOWN, 3); - let analyzed_dag = super::analyze(&dag, &CONFIG, &None, 1, true).unwrap(); + let out1 = dag.add_lut(a, FunctionTable::UNKWOWN, 3); + let out2 = dag.add_lut(b, FunctionTable::UNKWOWN, 3); + dag.add_compositions([out1, out2], [input1, input2]); + let analyzed_dag = super::analyze(&dag, &CONFIG, &None, PartitionIndex(1)).unwrap(); assert_eq!(analyzed_dag.nb_partitions, 3); let actual_constraint_strings = analyzed_dag .variance_constraints @@ -619,7 +887,10 @@ pub mod tests { "1σ²Br[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", ]; assert_eq!(actual_constraint_strings, expected_constraint_strings); - let partitions = vec![1, 1, 0, 1, 1, 1, 2, 0]; + let partitions = [1, 1, 0, 1, 1, 1, 2, 0] + .into_iter() + .map(PartitionIndex) + .collect::>(); assert_eq!( partitions, analyzed_dag @@ -630,10 +901,10 @@ pub mod tests { ); assert!(analyzed_dag.instrs_partition[6] .alternative_output_representation - .contains(&1)); + .contains(&PartitionIndex(1))); assert!(analyzed_dag.instrs_partition[7] .alternative_output_representation - .contains(&1)); + .contains(&PartitionIndex(1))); } #[allow(clippy::needless_range_loop)] @@ -710,18 +981,18 @@ pub mod tests { // First layer is fully LOW_PRECISION_PARTITION for op_i in input1.0..lut1.0 { let p = LOW_PRECISION_PARTITION; - let sb = &dag.out_variances[op_i][p]; + let sb = 
&dag.instrs_variances[op_i][p]; assert!(sb.coeff_input(p) >= 1.0 || sb.coeff_pbs(p) >= 1.0); assert!(nan_symbolic_variance( - &dag.out_variances[op_i][HIGH_PRECISION_PARTITION] + &dag.instrs_variances[op_i][HIGH_PRECISION_PARTITION] )); } // First lut is HIGH_PRECISION_PARTITION and immedialtely converted to LOW_PRECISION_PARTITION let p = HIGH_PRECISION_PARTITION; - let sb = &dag.out_variances[lut1.0][p]; + let sb = &dag.instrs_variances[lut1.0][p]; assert!(sb.coeff_input(p) == 0.0); assert!(sb.coeff_pbs(p) == 1.0); - let sb_after_fast_ks = &dag.out_variances[lut1.0][LOW_PRECISION_PARTITION]; + let sb_after_fast_ks = &dag.instrs_variances[lut1.0][LOW_PRECISION_PARTITION]; assert!( sb_after_fast_ks.coeff_partition_keyswitch_to_big( HIGH_PRECISION_PARTITION, @@ -732,7 +1003,7 @@ pub mod tests { for op_i in (lut1.0 + 1)..lut2.0 { assert!(LOW_PRECISION_PARTITION == dag.instrs_partition[op_i].instruction_partition); let p = LOW_PRECISION_PARTITION; - let sb = &dag.out_variances[op_i][p]; + let sb = &dag.instrs_variances[op_i][p]; // The base noise is either from the other partition and shifted or from the current partition and 1 assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); assert!(sb.coeff_input(HIGH_PRECISION_PARTITION) == 0.0); @@ -755,9 +1026,9 @@ pub mod tests { } } assert!(nan_symbolic_variance( - &dag.out_variances[lut2.0][LOW_PRECISION_PARTITION] + &dag.instrs_variances[lut2.0][LOW_PRECISION_PARTITION] )); - let sb = &dag.out_variances[lut2.0][HIGH_PRECISION_PARTITION]; + let sb = &dag.instrs_variances[lut2.0][HIGH_PRECISION_PARTITION]; assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0); } @@ -776,12 +1047,12 @@ pub mod tests { show_partitionning(&old_dag, &dag.instrs_partition); // First layer is fully HIGH_PRECISION_PARTITION assert!( - dag.out_variances[free_input1.0][HIGH_PRECISION_PARTITION] + dag.instrs_variances[free_input1.0][HIGH_PRECISION_PARTITION] .coeff_input(HIGH_PRECISION_PARTITION) == 1.0 ); // First layer tlu - let sb = 
&dag.out_variances[input1.0][HIGH_PRECISION_PARTITION]; + let sb = &dag.instrs_variances[input1.0][HIGH_PRECISION_PARTITION]; assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); assert!( @@ -789,7 +1060,7 @@ pub mod tests { == 0.0 ); // The same cyphertext exists in another partition with additional noise due to fast keyswitch - let sb = &dag.out_variances[input1.0][LOW_PRECISION_PARTITION]; + let sb = &dag.instrs_variances[input1.0][LOW_PRECISION_PARTITION]; assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); assert!( @@ -801,7 +1072,7 @@ pub mod tests { let mut first_bit_extract_verified = false; let mut first_bit_erase_verified = false; for op_i in (input1.0 + 1)..rounded1.0 { - if let Op::Dot { + if let Operator::Dot { weights, inputs, .. } = &dag.operators[op_i] { @@ -809,7 +1080,7 @@ pub mod tests { let first_bit_extract = bit_extract && !first_bit_extract_verified; let bit_erase = weights.values == [1, -1]; let first_bit_erase = bit_erase && !first_bit_erase_verified; - let input0_sb = &dag.out_variances[inputs[0].0][LOW_PRECISION_PARTITION]; + let input0_sb = &dag.instrs_variances[inputs[0].0][LOW_PRECISION_PARTITION]; let input0_coeff_pbs_high = input0_sb.coeff_pbs(HIGH_PRECISION_PARTITION); let input0_coeff_pbs_low = input0_sb.coeff_pbs(LOW_PRECISION_PARTITION); let input0_coeff_fks = input0_sb.coeff_partition_keyswitch_to_big( @@ -827,7 +1098,7 @@ pub mod tests { assert!(input0_coeff_fks == 1.0); } else if bit_erase { first_bit_erase_verified |= first_bit_erase; - let input1_sb = &dag.out_variances[inputs[1].0][LOW_PRECISION_PARTITION]; + let input1_sb = &dag.instrs_variances[inputs[1].0][LOW_PRECISION_PARTITION]; let input1_coeff_pbs_high = input1_sb.coeff_pbs(HIGH_PRECISION_PARTITION); let input1_coeff_pbs_low = input1_sb.coeff_pbs(LOW_PRECISION_PARTITION); let input1_coeff_fks = input1_sb.coeff_partition_keyswitch_to_big( @@ 
-1031,14 +1302,8 @@ pub mod tests { _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, 1); let precisions: Vec<_> = (1..=max_precision).collect(); let p_cut = PartitionCut::from_precisions(&precisions); - let dag = super::analyze( - &dag, - &CONFIG, - &Some(p_cut.clone()), - LOW_PRECISION_PARTITION, - false, - ) - .unwrap(); + let dag = + super::analyze(&dag, &CONFIG, &Some(p_cut.clone()), LOW_PRECISION_PARTITION).unwrap(); assert!(dag.nb_partitions == p_cut.p_cut.len() + 1); } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs index 460b634e2..c8a787712 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::utils::f64::f64_dot; -use super::operations_value::OperationsValue; +use super::{operations_value::OperationsValue, partitions::PartitionIndex}; #[derive(Clone, Debug)] pub struct OperationsCount { @@ -26,8 +26,8 @@ impl fmt::Display for OperationsCount { let counts = &self.counts; let nb_partitions = counts.nb_partitions(); let index = &counts.index; - for src_partition in 0..nb_partitions { - for dst_partition in 0..nb_partitions { + for src_partition in PartitionIndex::range(0, nb_partitions) { + for dst_partition in PartitionIndex::range(0, nb_partitions) { let coeff = counts.values[index.keyswitch_to_small(src_partition, dst_partition)]; if coeff != 0.0 { if src_partition == dst_partition { @@ -39,14 +39,14 @@ impl fmt::Display for OperationsCount { } } } - for src_partition in 0..nb_partitions { + for src_partition in PartitionIndex::range(0, nb_partitions) { assert!(counts.values[index.input(src_partition)] == 0.0); let coeff = counts.values[index.pbs(src_partition)]; if coeff != 0.0 { 
write!(f, "{add_plus}{coeff}¢Br[{src_partition}]")?; add_plus = " + "; } - for dst_partition in 0..nb_partitions { + for dst_partition in PartitionIndex::range(0, nb_partitions) { let coeff = counts.values[index.keyswitch_to_big(src_partition, dst_partition)]; if coeff != 0.0 { write!(f, "{add_plus}{coeff}¢FK[{src_partition}→{dst_partition}]")?; @@ -55,7 +55,7 @@ impl fmt::Display for OperationsCount { } } - for partition in 0..nb_partitions { + for partition in PartitionIndex::range(0, nb_partitions) { assert!(counts.values[index.modulus_switching(partition)] == 0.0); } if add_plus.is_empty() { @@ -80,8 +80,8 @@ impl Complexity { &self, complexity_cut: f64, costs: &OperationsValue, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, ) -> f64 { let ks_index = costs.index.keyswitch_to_small(src_partition, dst_partition); let actual_ks_cost = costs.values[ks_index]; @@ -98,8 +98,8 @@ impl Complexity { &self, complexity_cut: f64, costs: &OperationsValue, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, ) -> f64 { let fks_index = costs.index.keyswitch_to_big(src_partition, dst_partition); let actual_fks_cost = costs.values[fks_index]; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs index f0bf81fc0..c8ff041fd 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs @@ -29,7 +29,7 @@ impl Feasible { pub fn pbs_max_feasible_variance( &self, operations_variance: &OperationsValue, - partition: usize, + partition: PartitionIndex, ) -> f64 { let pbs_index = operations_variance.index.pbs(partition); let actual_pbs_variance = 
operations_variance.values[pbs_index]; @@ -52,8 +52,8 @@ impl Feasible { pub fn ks_max_feasible_variance( &self, operations_variance: &OperationsValue, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, ) -> f64 { let ks_index = operations_variance .index @@ -81,8 +81,8 @@ impl Feasible { pub fn fks_max_feasible_variance( &self, operations_variance: &OperationsValue, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, ) -> f64 { let fks_index = operations_variance .index @@ -195,7 +195,7 @@ impl Feasible { .iter() .filter(|constraint| { constraint.partition == partition - || (0..nb_partitions).any(|i| touch_any_ks(constraint, i)) + || PartitionIndex::range(0, nb_partitions).any(|i| touch_any_ks(constraint, i)) }) .cloned() .collect(); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index 661b7e13b..6bd16dd23 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -1,4 +1,4 @@ -mod analyze; +pub(crate) mod analyze; mod complexity; mod fast_keyswitch; mod feasible; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs index 875706aff..e70b21426 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs @@ -1,5 +1,7 @@ use std::ops::{Deref, DerefMut}; +use super::partitions::PartitionIndex; + /** * Index 
actual operations (input, ks, pbs, fks, modulus switching, etc). */ @@ -80,46 +82,54 @@ impl Indexing { self.nb_partitions * (STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions) } - pub fn input(&self, partition: usize) -> usize { - assert!(partition < self.nb_partitions); - self.maybe_compressed(partition * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH) + pub fn input(&self, partition: PartitionIndex) -> usize { + assert!(partition.0 < self.nb_partitions); + self.maybe_compressed(partition.0 * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH) } - pub fn pbs(&self, partition: usize) -> usize { - assert!(partition < self.nb_partitions); - self.maybe_compressed(partition * self.nb_coeff_per_partition() + VALUE_INDEX_PBS) + pub fn pbs(&self, partition: PartitionIndex) -> usize { + assert!(partition.0 < self.nb_partitions); + self.maybe_compressed(partition.0 * self.nb_coeff_per_partition() + VALUE_INDEX_PBS) } - pub fn modulus_switching(&self, partition: usize) -> usize { - assert!(partition < self.nb_partitions); - self.maybe_compressed(partition * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS) + pub fn modulus_switching(&self, partition: PartitionIndex) -> usize { + assert!(partition.0 < self.nb_partitions); + self.maybe_compressed(partition.0 * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS) } - pub fn keyswitch_to_small(&self, src_partition: usize, dst_partition: usize) -> usize { - assert!(src_partition < self.nb_partitions); - assert!(dst_partition < self.nb_partitions); + pub fn keyswitch_to_small( + &self, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, + ) -> usize { + assert!(src_partition.0 < self.nb_partitions); + assert!(dst_partition.0 < self.nb_partitions); self.maybe_compressed( // Skip other partition - dst_partition * self.nb_coeff_per_partition() + dst_partition.0 * self.nb_coeff_per_partition() // Skip non keyswitchs + STABLE_NB_VALUES_BY_PARTITION // Select the right keyswicth to small - + src_partition, + + 
src_partition.0, ) } - pub fn keyswitch_to_big(&self, src_partition: usize, dst_partition: usize) -> usize { - assert!(src_partition < self.nb_partitions); - assert!(dst_partition < self.nb_partitions); + pub fn keyswitch_to_big( + &self, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, + ) -> usize { + assert!(src_partition.0 < self.nb_partitions); + assert!(dst_partition.0 < self.nb_partitions); self.maybe_compressed( // Skip other partition - dst_partition * self.nb_coeff_per_partition() + dst_partition.0 * self.nb_coeff_per_partition() // Skip non keyswitchs + STABLE_NB_VALUES_BY_PARTITION // Skip keyswitch to small + self.nb_keyswitchs_per_partition() // Select the right keyswicth to big - + src_partition, + + src_partition.0, ) } @@ -131,12 +141,23 @@ impl Indexing { /** * Represent any values indexed by actual operations (input, pbs, modulus switching, ks, fks, , etc) variance, */ -#[derive(Clone, Debug, PartialEq, PartialOrd)] +#[derive(Clone, Debug, PartialOrd)] pub struct OperationsValue { pub index: Indexing, pub values: Vec, } +impl PartialEq for OperationsValue { + fn eq(&self, other: &Self) -> bool { + self.index == other.index + && self + .values + .iter() + .zip(other.values.iter()) + .all(|(a, b)| a.is_nan() && b.is_nan() || *a == *b) + } +} + impl OperationsValue { pub const ZERO: Self = Self { index: Indexing { @@ -172,23 +193,27 @@ impl OperationsValue { } } - pub fn input(&mut self, partition: usize) -> &mut f64 { + pub fn input(&mut self, partition: PartitionIndex) -> &mut f64 { &mut self.values[self.index.input(partition)] } - pub fn pbs(&mut self, partition: usize) -> &mut f64 { + pub fn pbs(&mut self, partition: PartitionIndex) -> &mut f64 { &mut self.values[self.index.pbs(partition)] } - pub fn ks(&mut self, src_partition: usize, dst_partition: usize) -> &mut f64 { + pub fn ks(&mut self, src_partition: PartitionIndex, dst_partition: PartitionIndex) -> &mut f64 { &mut self.values[self.index.keyswitch_to_small(src_partition, 
dst_partition)] } - pub fn fks(&mut self, src_partition: usize, dst_partition: usize) -> &mut f64 { + pub fn fks( + &mut self, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, + ) -> &mut f64 { &mut self.values[self.index.keyswitch_to_big(src_partition, dst_partition)] } - pub fn modulus_switching(&mut self, partition: usize) -> &mut f64 { + pub fn modulus_switching(&mut self, partition: PartitionIndex) -> &mut f64 { &mut self.values[self.index.modulus_switching(partition)] } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs index b60ede5e2..572d6561f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs @@ -68,9 +68,9 @@ struct OperationsCV { cost: OperationsValue, } -type KsSrc = usize; -type KsDst = usize; -type FksSrc = usize; +type KsSrc = PartitionIndex; +type KsDst = PartitionIndex; +type FksSrc = PartitionIndex; #[inline(never)] fn optimize_1_ks( @@ -124,7 +124,8 @@ fn optimize_many_independant_ks( let mut operations = operations.clone(); let mut ks_bests = Vec::with_capacity(macro_parameters.len()); for (ks_dst, macro_dst) in macro_parameters.iter().enumerate() { - if !ks_used[ks_src][ks_dst] { + let ks_dst = PartitionIndex(ks_dst); + if !ks_used[ks_src.0][ks_dst.0] { continue; } let output_dim = macro_dst.internal_dim; @@ -153,8 +154,8 @@ struct Best1FksAndManyKs { fn optimize_1_fks_and_all_compatible_ks( macro_parameters: &[MacroParameters], ks_used: &[Vec], - fks_src: usize, - fks_dst: usize, + fks_src: PartitionIndex, + fks_dst: PartitionIndex, operations: &OperationsCV, feasible: &Feasible, complexity: &Complexity, @@ -164,15 +165,15 @@ fn optimize_1_fks_and_all_compatible_ks( fft_precision: u32, ) -> 
Option<(Best1FksAndManyKs, OperationsCV)> { // At this point every thing else is known apart fks and ks - let input_glwe = macro_parameters[fks_src].glwe_params; - let output_glwe = macro_parameters[fks_dst].glwe_params; + let input_glwe = macro_parameters[fks_src.0].glwe_params; + let output_glwe = macro_parameters[fks_dst.0].glwe_params; let output_lwe_dim = output_glwe.sample_extract_lwe_dimension(); // OPT: have a separate cache for fks let ks_pareto = caches.pareto_quantities(output_lwe_dim).to_owned(); // TODO: fast ks in the other direction as well let use_fast_ks = REAL_FAST_KS && input_glwe.sample_extract_lwe_dimension() >= output_lwe_dim; let ks_src = fks_dst; - let ks_input_dim = macro_parameters[fks_dst] + let ks_input_dim = macro_parameters[fks_dst.0] .glwe_params .sample_extract_lwe_dimension(); let mut operations = operations.clone(); @@ -296,7 +297,7 @@ fn optimize_dst_exclusive_fks_subset_and_all_ks( macro_parameters, ks_used, *fks_src, - fks_dst, + PartitionIndex(fks_dst), &acc_operations, feasible, complexity, @@ -311,7 +312,7 @@ fn optimize_dst_exclusive_fks_subset_and_all_ks( // There is no fks to optimize let (many_ks, operations) = optimize_many_independant_ks( macro_parameters, - ks_src, + PartitionIndex(ks_src), ks_input_lwe_dim, ks_used, &acc_operations, @@ -408,10 +409,10 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks( let mut ks = vec![vec![None; nb_partitions]; nb_partitions]; for (fks_dst, one_best_fks_ks) in best_fks_ks.iter().enumerate() { if let Some((fks_src, sol_fks)) = one_best_fks_ks.fks { - fks[fks_src][fks_dst] = Some(sol_fks); + fks[fks_src.0][fks_dst] = Some(sol_fks); } for (ks_dst, sol_ks) in &one_best_fks_ks.many_ks { - ks[fks_dst][*ks_dst] = Some(*sol_ks); + ks[fks_dst][ks_dst.0] = Some(*sol_ks); } } Some(PartialMicroParameters { @@ -432,11 +433,11 @@ fn apply_all_ks_lower_bound( operations: &mut OperationsCV, ) { for (src, dst) in cross_partition(nb_partitions) { - if !used_tlu_keyswitch[src][dst] { + if 
!used_tlu_keyswitch[src.0][dst.0] { continue; } - let in_glwe_params = macro_parameters[src].glwe_params; - let out_internal_dim = macro_parameters[dst].internal_dim; + let in_glwe_params = macro_parameters[src.0].glwe_params; + let out_internal_dim = macro_parameters[dst.0].internal_dim; let ks_pareto = caches.pareto_quantities(out_internal_dim); let in_lwe_dim = in_glwe_params.sample_extract_lwe_dimension(); *operations.variance.ks(src, dst) = keyswitch::lowest_noise_ks(ks_pareto, in_lwe_dim); @@ -456,11 +457,11 @@ fn apply_fks_variance_and_cost_or_lower_bound( fft_precision: u32, ) { for (src, dst) in cross_partition(nb_partitions) { - if !used_conversion_keyswitch[src][dst] { + if !used_conversion_keyswitch[src.0][dst.0] { continue; } - let input_glwe = ¯o_parameters[src].glwe_params; - let output_glwe = ¯o_parameters[dst].glwe_params; + let input_glwe = ¯o_parameters[src.0].glwe_params; + let output_glwe = ¯o_parameters[dst.0].glwe_params; if input_glwe == output_glwe { *operations.variance.fks(src, dst) = 0.0; *operations.cost.fks(src, dst) = 0.0; @@ -468,8 +469,8 @@ fn apply_fks_variance_and_cost_or_lower_bound( } // if an optimized fks is applicable and is not to be optimized // we use the already optimized fks instead of a lower bound - if let Some(fks) = initial_fks[src][dst] { - let to_be_optimized = fks_to_optimize[src].map_or(false, |fdst| dst == fdst); + if let Some(fks) = initial_fks[src.0][dst.0] { + let to_be_optimized = fks_to_optimize[src.0].map_or(false, |fdst| dst == fdst); if !to_be_optimized { if input_glwe == &fks.src_glwe_param && output_glwe == &fks.dst_glwe_param { *operations.variance.fks(src, dst) = fks.noise; @@ -519,17 +520,17 @@ fn apply_partitions_input_and_modulus_variance_and_cost( variance_modulus_switching: f64, operations: &mut OperationsCV, ) { - for i in 0..nb_partitions { + for i in PartitionIndex::range(0, nb_partitions) { let (input_variance, variance_modulus_switching) = - if macro_parameters[i] == 
macro_parameters[partition] { + if macro_parameters[i.0] == macro_parameters[partition.0] { (input_variance, variance_modulus_switching) } else { - let input_variance = macro_parameters[i] + let input_variance = macro_parameters[i.0] .glwe_params .minimal_variance(ciphertext_modulus_log, security_level); let variance_modulus_switching = estimate_modulus_switching_noise_with_binary_key( - macro_parameters[i].internal_dim, - macro_parameters[i].glwe_params.log2_polynomial_size, + macro_parameters[i.0].internal_dim, + macro_parameters[i.0].glwe_params.log2_polynomial_size, ciphertext_modulus_log, ); (input_variance, variance_modulus_switching) @@ -548,15 +549,16 @@ fn apply_pbs_variance_and_cost_or_lower_bounds( ) { // setting already chosen pbs and lower bounds for (i, pbs) in initial_pbs.iter().enumerate() { + let i = PartitionIndex(i); let pbs = if i == partition { &None } else { pbs }; if let Some(pbs) = pbs { - let internal_dim = macro_parameters[i].internal_dim; + let internal_dim = macro_parameters[i.0].internal_dim; *operations.variance.pbs(i) = pbs.noise_br(internal_dim); *operations.cost.pbs(i) = pbs.complexity_br(internal_dim); } else { // OPT: Most values could be shared on first optimize_macro - let in_internal_dim = macro_parameters[i].internal_dim; - let out_glwe_params = macro_parameters[i].glwe_params; + let in_internal_dim = macro_parameters[i.0].internal_dim; + let out_glwe_params = macro_parameters[i.0].glwe_params; let variance_min = cmux::lowest_noise_br(caches.pareto_quantities(out_glwe_params), in_internal_dim); *operations.variance.pbs(i) = variance_min; @@ -576,28 +578,28 @@ fn fks_to_optimize( // When fks is unused a None is used to keep the same loop structure. 
let mut fks_paretos: Vec> = vec![]; fks_paretos.reserve_exact(nb_partitions); - for fks_dst in 0..nb_partitions { + for fks_dst in PartitionIndex::range(0, nb_partitions) { // find the i-th valid fks_src - let fks_src = if used_conversion_keyswitch[optimized_partition][fks_dst] { + let fks_src = if used_conversion_keyswitch[optimized_partition.0][fks_dst.0] { Some(optimized_partition) } else { let mut count_used: usize = 0; let mut fks_src = None; #[allow(clippy::needless_range_loop)] - for src in 0..nb_partitions { - let used = used_conversion_keyswitch[src][fks_dst]; - if used && count_used == optimized_partition { + for src in PartitionIndex::range(0, nb_partitions) { + let used = used_conversion_keyswitch[src.0][fks_dst.0]; + if used && count_used == optimized_partition.0 { fks_src = Some(src); break; } count_used += used as usize; } if fks_src.is_none() && count_used > 0 { - let n_th = optimized_partition % count_used; + let n_th = optimized_partition.0 % count_used; count_used = 0; #[allow(clippy::needless_range_loop)] - for src in 0..nb_partitions { - let used = used_conversion_keyswitch[src][fks_dst]; + for src in PartitionIndex::range(0, nb_partitions) { + let used = used_conversion_keyswitch[src.0][fks_dst.0]; if used && count_used == n_th { fks_src = Some(src); break; @@ -632,7 +634,7 @@ fn optimize_macro( best_p_error: f64, ) -> Parameters { let nb_partitions = init_parameters.macro_params.len(); - assert!(partition < nb_partitions); + assert!(partition.0 < nb_partitions); let variance_modulus_switching_of = |glwe_log2_poly_size, internal_lwe_dimensions| { estimate_modulus_switching_noise_with_binary_key( @@ -686,12 +688,12 @@ fn optimize_macro( }; // Heuristic to fill missing macro parameters - let macros: Vec<_> = (0..nb_partitions) + let macros: Vec<_> = PartitionIndex::range(0, nb_partitions) .map(|i| { if i == partition { macro_param_partition } else { - init_parameters.macro_params[i].unwrap_or(macro_param_partition) + 
init_parameters.macro_params[i.0].unwrap_or(macro_param_partition) } }) .collect(); @@ -765,7 +767,7 @@ fn optimize_macro( // here we optimize for feasibility only // if nothing is feasible, it will give improves feasability for later iterations let mut macro_params = init_parameters.macro_params.clone(); - macro_params[partition] = Some(MacroParameters { + macro_params[partition.0] = Some(MacroParameters { glwe_params, internal_dim, }); @@ -781,7 +783,7 @@ fn optimize_macro( let p_error = feasible.p_error(&operations.variance); let global_p_error = feasible.global_p_error(&operations.variance); let mut pbs = init_parameters.micro_params.pbs.clone(); - pbs[partition] = Some(cmux_params); + pbs[partition.0] = Some(cmux_params); let micro_params = MicroParameters { pbs, ks: vec![vec![None; nb_partitions]; nb_partitions], @@ -825,35 +827,36 @@ fn optimize_macro( // optimize_micro has already checked for best-ness lb_message = None; let mut macro_params = init_parameters.macro_params.clone(); - macro_params[partition] = Some(macro_param_partition); + macro_params[partition.0] = Some(macro_param_partition); let mut is_lower_bound = macro_params.iter().any(Option::is_none); if is_lower_bound { lb_message = Some("is_lower_bound due to missing macro parameter"); } // copy back pbs from other partition let mut all_pbs = init_parameters.micro_params.pbs.clone(); - all_pbs[partition] = Some(some_micro_params.pbs); + all_pbs[partition.0] = Some(some_micro_params.pbs); let mut all_fks = init_parameters.micro_params.fks.clone(); for (dst_partition, maybe_fks) in fks_to_optimize.iter().enumerate() { + let dst_partition = PartitionIndex(dst_partition); if let &Some(src_partition) = maybe_fks { - all_fks[src_partition][dst_partition] = - some_micro_params.fks[src_partition][dst_partition]; - assert!(used_conversion_keyswitch[src_partition][dst_partition]); - assert!(all_fks[src_partition][dst_partition].is_some()); + all_fks[src_partition.0][dst_partition.0] = + 
some_micro_params.fks[src_partition.0][dst_partition.0]; + assert!(used_conversion_keyswitch[src_partition.0][dst_partition.0]); + assert!(all_fks[src_partition.0][dst_partition.0].is_some()); } } // As all fks cannot be re-optimized in some case, we need to check previous ones are still valid. for (src_partition, dst_partition) in cross_partition(nb_partitions) { - if !used_conversion_keyswitch[src_partition][dst_partition] { + if !used_conversion_keyswitch[src_partition.0][dst_partition.0] { continue; } - let fks = &all_fks[src_partition][dst_partition]; + let fks = &all_fks[src_partition.0][dst_partition.0]; if !is_lower_bound && fks.is_none() { lb_message = Some("is_lower_bound due to missing fast keyswitch parameter"); is_lower_bound = true; } - let src_glwe_param = macro_params[src_partition].map(|p| p.glwe_params); - let dst_glwe_param = macro_params[dst_partition].map(|p| p.glwe_params); + let src_glwe_param = macro_params[src_partition.0].map(|p| p.glwe_params); + let dst_glwe_param = macro_params[dst_partition.0].map(|p| p.glwe_params); let src_glwe_param_stable = src_glwe_param == fks.map(|p| p.src_glwe_param); let dst_glwe_param_stable = dst_glwe_param == fks.map(|p| p.dst_glwe_param); if src_glwe_param_stable && dst_glwe_param_stable { @@ -862,7 +865,7 @@ fn optimize_macro( if !is_lower_bound { lb_message = Some("is_lower_bound due to changing others fks macro param"); } - all_fks[src_partition][dst_partition] = None; + all_fks[src_partition.0][dst_partition.0] = None; is_lower_bound = true; } let micro_params = MicroParameters { @@ -894,8 +897,9 @@ fn optimize_macro( best_parameters } -fn cross_partition(nb_partitions: usize) -> impl Iterator { - (0..nb_partitions).flat_map(move |a: usize| (0..nb_partitions).map(move |b: usize| (a, b))) +fn cross_partition(nb_partitions: usize) -> impl Iterator { + PartitionIndex::range(0, nb_partitions) + .flat_map(move |a| PartitionIndex::range(0, nb_partitions).map(move |b| (a, b))) } 
#[allow(clippy::too_many_lines, clippy::missing_errors_doc)] @@ -910,14 +914,13 @@ pub fn optimize( let ciphertext_modulus_log = config.ciphertext_modulus_log; let fft_precision = config.fft_precision; let security_level = config.security_level; - let composable = config.composable; let noise_config = NoiseBoundConfig { security_level, maximum_acceptable_error_probability: config.maximum_acceptable_error_probability, ciphertext_modulus_log, }; - let dag = analyze(dag, &noise_config, p_cut, default_partition, composable)?; + let dag = analyze(dag, &noise_config, p_cut, default_partition)?; let kappa = error::sigma_scale_of_error_probability(config.maximum_acceptable_error_probability); @@ -950,7 +953,7 @@ pub fn optimize( let mut fix_point = params.clone(); let mut best_params: Option = None; for iter in 0..=10 { - for partition in (0..nb_partitions).rev() { + for partition in PartitionIndex::range(0, nb_partitions).rev() { let new_params = optimize_macro( security_level, ciphertext_modulus_log, @@ -1045,7 +1048,7 @@ fn used_tlu_keyswitch(dag: &AnalyzedDag) -> Vec> { .coeff_keyswitch_to_small(src_partition, dst_partition) != 0.0 { - result[src_partition][dst_partition] = true; + result[src_partition.0][dst_partition.0] = true; break; } } @@ -1062,7 +1065,7 @@ fn used_conversion_keyswitch(dag: &AnalyzedDag) -> Vec> { .coeff_partition_keyswitch_to_big(src_partition, dst_partition) != 0.0 { - result[src_partition][dst_partition] = true; + result[src_partition.0][dst_partition.0] = true; break; } } @@ -1091,8 +1094,8 @@ fn sanity_check( cost: complexity.zero_cost(), }; let micro_params = ¶ms.micro_params; - for partition in 0..nb_partitions { - let partition_macro = params.macro_params[partition].unwrap(); + for partition in PartitionIndex::range(0, nb_partitions) { + let partition_macro = params.macro_params[partition.0].unwrap(); let glwe_param = partition_macro.glwe_params; let internal_dim = partition_macro.internal_dim; let input_variance = 
glwe_param.minimal_variance(ciphertext_modulus_log, security_level); @@ -1103,42 +1106,42 @@ fn sanity_check( ); *operations.variance.input(partition) = input_variance; *operations.variance.modulus_switching(partition) = variance_modulus_switching; - if let Some(pbs) = micro_params.pbs[partition] { + if let Some(pbs) = micro_params.pbs[partition.0] { *operations.variance.pbs(partition) = pbs.noise_br(internal_dim); *operations.cost.pbs(partition) = pbs.complexity_br(internal_dim); } else { *operations.variance.pbs(partition) = f64::MAX; *operations.cost.pbs(partition) = f64::MAX; } - for src_partition in 0..nb_partitions { - let src_partition_macro = params.macro_params[src_partition].unwrap(); + for src_partition in PartitionIndex::range(0, nb_partitions) { + let src_partition_macro = params.macro_params[src_partition.0].unwrap(); let src_glwe_param = src_partition_macro.glwe_params; let src_lwe_dim = src_glwe_param.sample_extract_lwe_dimension(); - if let Some(ks) = micro_params.ks[src_partition][partition] { + if let Some(ks) = micro_params.ks[src_partition.0][partition.0] { assert!( - used_tlu_keyswitch[src_partition][partition], + used_tlu_keyswitch[src_partition.0][partition.0], "Superflous ks[{src_partition}->{partition}]" ); *operations.variance.ks(src_partition, partition) = ks.noise(src_lwe_dim); *operations.cost.ks(src_partition, partition) = ks.complexity(src_lwe_dim); } else { assert!( - !used_tlu_keyswitch[src_partition][partition], + !used_tlu_keyswitch[src_partition.0][partition.0], "Missing ks[{src_partition}->{partition}]" ); *operations.variance.ks(src_partition, partition) = f64::MAX; *operations.cost.ks(src_partition, partition) = f64::MAX; } - if let Some(fks) = micro_params.fks[src_partition][partition] { + if let Some(fks) = micro_params.fks[src_partition.0][partition.0] { assert!( - used_conversion_keyswitch[src_partition][partition], + used_conversion_keyswitch[src_partition.0][partition.0], "Superflous fks[{src_partition}->{partition}]" 
); *operations.variance.fks(src_partition, partition) = fks.noise; *operations.cost.fks(src_partition, partition) = fks.complexity; } else { assert!( - !used_conversion_keyswitch[src_partition][partition], + !used_conversion_keyswitch[src_partition.0][partition.0], "Missing fks[{src_partition}->{partition}]" ); *operations.variance.fks(src_partition, partition) = f64::MAX; @@ -1164,7 +1167,7 @@ pub fn optimize_to_circuit_solution( ) -> keys_spec::CircuitSolution { if lut_count_from_dag(dag) == 0 { // If there are no lut in the dag the noise is never refresh so the dag cannot be composable - if config.composable { + if dag.is_composed() { return keys_spec::CircuitSolution::no_solution( NotComposable("No luts in the circuit.".into()).to_string(), ); @@ -1176,7 +1179,7 @@ pub fn optimize_to_circuit_solution( } return keys_spec::CircuitSolution::no_solution(NoParametersFound.to_string()); } - let default_partition = 0; + let default_partition = PartitionIndex::FIRST; let dag_and_params = optimize( dag, config, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index 578467669..1df078e1a 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -28,7 +28,7 @@ static SHARED_CACHES: Lazy = Lazy::new(|| { const _4_SIGMA: f64 = 0.000_063_342_483_999_973; -const LOW_PARTITION: PartitionIndex = 0; +const LOW_PARTITION: PartitionIndex = PartitionIndex(0); static CPU_COMPLEXITY: Lazy = Lazy::new(CpuComplexity::default); @@ -41,14 +41,13 @@ fn default_config() -> Config<'static> { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model, - composable: false, } } fn optimize( dag: &unparametrized::Dag, p_cut: &Option, - 
default_partition: usize, + default_partition: PartitionIndex, ) -> Option { let config = default_config(); let search_space = SearchSpace::default_cpu(); @@ -288,7 +287,7 @@ fn optimize_multi_independant_2_partitions_finally_added_and_luted() { let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, true); let sol_1 = single_precision_sol[precision1 as usize].clone(); let sol_2 = single_precision_sol[precision2 as usize].clone(); - let sol_multi = optimize(&dag_multi, &p_cut, 0); + let sol_multi = optimize(&dag_multi, &p_cut, PartitionIndex(0)); let feasible_multi = sol_multi.is_some(); let feasible_2 = sol_2.is_some(); assert!(feasible_multi); @@ -335,7 +334,7 @@ fn optimize_multi_independant_2_partitions_finally_added_and_luted() { fn optimize_rounded(dag: &unparametrized::Dag) -> Option { let p_cut = Some(PartitionCut::from_precisions(&[1, 128])); - let default_partition = 0; + let default_partition = PartitionIndex(0); optimize(dag, &p_cut, default_partition) } @@ -385,8 +384,8 @@ fn test_optimize_v3_expanded_round( [true, false], // FKS[1->0] ]; for (src, dst) in cross_partition(2) { - assert!(sol.micro_params.ks[src][dst].is_some() == expected_ks[src][dst]); - assert!(sol.micro_params.fks[src][dst].is_some() == expected_fks[src][dst]); + assert!(sol.micro_params.ks[src.0][dst.0].is_some() == expected_ks[src.0][dst.0]); + assert!(sol.micro_params.fks[src.0][dst.0].is_some() == expected_fks[src.0][dst.0]); } } @@ -477,7 +476,7 @@ fn test_partition_chain(decreasing: bool) { ); _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision); let mut p_cut = PartitionCut::empty(); - let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); + let sol = optimize(&dag, &Some(p_cut.clone()), PartitionIndex(0)).unwrap(); assert!(sol.macro_params.len() == 1); let mut complexity = sol.complexity; for &out_precision in &input_precisions { @@ -487,7 +486,7 @@ fn test_partition_chain(decreasing: bool) { } p_cut.p_cut.push((out_precision, f64::MAX)); 
p_cut.p_cut.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); + let sol = optimize(&dag, &Some(p_cut.clone()), PartitionIndex(0)).unwrap(); let nb_partitions = sol.macro_params.len(); assert!( nb_partitions == (p_cut.p_cut.len() + 1), @@ -500,22 +499,23 @@ fn test_partition_chain(decreasing: bool) { sol.complexity ); for (src, dst) in cross_partition(nb_partitions) { - let ks = sol.micro_params.ks[src][dst]; + let ks = sol.micro_params.ks[src.0][dst.0]; eprintln!("{} {src} {dst}", ks.is_some()); - let expected_ks = (!decreasing || src == dst + 1) && (decreasing || src + 1 == dst) - || (src == dst && (src == 0 || src == nb_partitions - 1)); + let expected_ks = (!decreasing || src.0 == dst.0 + 1) + && (decreasing || src.0 + 1 == dst.0) + || (src == dst && (src == PartitionIndex(0) || src.0 == nb_partitions - 1)); assert!( ks.is_some() == expected_ks, "{:?} {:?}", ks.is_some(), expected_ks ); - let fks = sol.micro_params.fks[src][dst]; + let fks = sol.micro_params.fks[src.0][dst.0]; assert!(fks.is_none()); } complexity = sol.complexity; } - let sol = optimize(&dag, &None, 0); + let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.unwrap().complexity == complexity); } @@ -567,7 +567,7 @@ fn test_independant_partitions_non_feasible_single_params() { let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution; assert!(sol_single.is_none()); // solves in multi - let sol = optimize(&dag, &None, 0); + let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); let sol = sol.unwrap(); // check optimality @@ -600,7 +600,7 @@ fn test_chained_partitions_non_feasible_single_params() { ); let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution; assert!(sol_single.is_none()); - let sol = optimize(&dag, &None, 0); + let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); } @@ -611,13 +611,13 @@ fn test_multi_rounded_fks_coherency() { let reduced_8 = 
dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 8); let reduced_4 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); _ = dag.add_dot([reduced_8, reduced_4], [1, 1]); - let sol = optimize(&dag, &None, 0); + let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); let sol = sol.unwrap(); for (src, dst) in cross_partition(sol.macro_params.len()) { - if let Some(fks) = sol.micro_params.fks[src][dst] { - assert!(fks.src_glwe_param == sol.macro_params[src].unwrap().glwe_params); - assert!(fks.dst_glwe_param == sol.macro_params[dst].unwrap().glwe_params); + if let Some(fks) = sol.micro_params.fks[src.0][dst.0] { + assert!(fks.src_glwe_param == sol.macro_params[src.0].unwrap().glwe_params); + assert!(fks.dst_glwe_param == sol.macro_params[dst.0].unwrap().glwe_params); } } } @@ -653,7 +653,6 @@ fn test_big_secret_key_sharing() { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), - composable: false, }; let config_no_sharing = Config { key_sharing: false, @@ -703,7 +702,6 @@ fn test_big_and_small_secret_key() { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), - composable: false, }; let config_no_sharing = Config { key_sharing: false, @@ -745,23 +743,26 @@ fn test_composition_2_partitions() { let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 6); let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); let input2 = dag.add_dot([input1, lut3], [1, 1]); - let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); - let normal_config = default_config(); - let composed_config = Config { - composable: true, - ..normal_config - }; + let out = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); let search_space = SearchSpace::default_cpu(); - let normal_sol = super::optimize(&dag, normal_config, &search_space, &SHARED_CACHES, &None, 1) - .unwrap() - .1; - let composed_sol = super::optimize( + let normal_sol = super::optimize( &dag, - 
composed_config, + default_config(), &search_space, &SHARED_CACHES, &None, - 1, + PartitionIndex(1), + ) + .unwrap() + .1; + dag.add_composition(out, input1); + let composed_sol = super::optimize( + &dag, + default_config(), + &search_space, + &SHARED_CACHES, + &None, + PartitionIndex(1), ) .unwrap() .1; @@ -773,23 +774,28 @@ fn test_composition_2_partitions() { fn test_composition_1_partition_not_composable() { let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(8, Shape::number()); - let input1 = dag.add_dot([input1], [1 << 16]); - let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); - let _ = dag.add_dot([lut1], [1 << 16]); + let dot = dag.add_dot([input1], [1 << 16]); + let lut1 = dag.add_lut(dot, FunctionTable::UNKWOWN, 8); + let oup = dag.add_dot([lut1], [1 << 16]); let normal_config = default_config(); - let composed_config = Config { - composable: true, - ..normal_config - }; + let composed_config = normal_config; let search_space = SearchSpace::default_cpu(); - let normal_sol = super::optimize(&dag, normal_config, &search_space, &SHARED_CACHES, &None, 1); + let normal_sol = super::optimize( + &dag, + normal_config, + &search_space, + &SHARED_CACHES, + &None, + PartitionIndex(1), + ); + dag.add_composition(oup, input1); let composed_sol = super::optimize( &dag, composed_config, &search_space, &SHARED_CACHES, &None, - 1, + PartitionIndex(1), ); assert!(normal_sol.is_ok()); assert!(composed_sol.is_err()); @@ -805,11 +811,11 @@ fn test_maximal_multi() { let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 8u8); _ = dag.add_dot([lut2], [1 << 16]); - let sol = optimize(&dag, &None, 0).unwrap(); + let sol = optimize(&dag, &None, PartitionIndex(0)).unwrap(); assert!(sol.macro_params.len() == 1); let p_cut = PartitionCut::maximal_partitionning(&dag); - let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); + let sol = optimize(&dag, &Some(p_cut.clone()), PartitionIndex(0)).unwrap(); assert!(sol.macro_params.len() == 2); eprintln!("{:?}", 
sol.micro_params.pbs); @@ -842,6 +848,6 @@ fn test_bug_with_zero_noise() { let v2 = dag.add_levelled_op([v1], complexity, 1.0, &out_shape, "comment"); let v3 = dag.add_unsafe_cast(v2, 1); let _ = dag.add_lut(v3, FunctionTable { values: vec![] }, 1); - let sol = optimize(&dag, &None, 0); + let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs index 049ed02f9..e2c84eb27 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize_generic.rs @@ -62,8 +62,9 @@ pub fn optimize( caches: &PersistDecompCaches, p_cut: &Option, ) -> CircuitSolution { - let native = || native_optimize(dag, config, search_space, caches, p_cut); - let crt = || crt_optimize(dag, config, search_space, default_log_norm2_woppbs, caches); + let dag = dag.clone(); + let native = || native_optimize(&dag, config, search_space, caches, p_cut); + let crt = || crt_optimize(&dag, config, search_space, default_log_norm2_woppbs, caches); match encoding { Encoding::Auto => best_complexity_solution(native(), crt()), Encoding::Native => native(), diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs index 5f775687c..c3e607bb5 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs @@ -66,10 +66,10 @@ impl PartitionCut { if dag.out_precisions[input.0] <= precision_cut && 
self.rnorm2(op_i) <= norm2_cut { - return Some(partition); + return Some(PartitionIndex(partition)); } } - Some(self.p_cut.len()) + Some(PartitionIndex(self.p_cut.len())) } _ => None, } @@ -193,7 +193,7 @@ impl PartitionCut { pub fn delete_unused_cut(&self, used: &HashSet) -> Self { let mut p_cut = vec![]; for (i, &cut) in self.p_cut.iter().enumerate() { - if used.contains(&i) { + if used.contains(&PartitionIndex(i)) { p_cut.push(cut); } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index 83707f371..fb4d7d66d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -46,7 +46,7 @@ impl Blocks { // Extract block of instructions connected by levelled ops. // This facilitates reasonning about conflicts on levelled ops. #[allow(clippy::match_same_arms)] -fn extract_levelled_block(dag: &unparametrized::Dag, composable: bool) -> Blocks { +fn extract_levelled_block(dag: &unparametrized::Dag) -> Blocks { let mut uf = UnionFind::new(dag.operators.len()); for (op_i, op) in dag.operators.iter().enumerate() { match op { @@ -64,23 +64,19 @@ fn extract_levelled_block(dag: &unparametrized::Dag, composable: bool) -> Blocks Op::Round { .. } => unreachable!("Round should have been expanded"), }; } - if composable { - // Without knowledge of how outputs are forwarded to inputs, we can't do better than putting - // all inputs and outputs in the same partition. 
- let mut input_iter = dag.get_input_operators_iter().map(|op| op.id.0); - let first_inp = input_iter.next().unwrap(); - dag.get_output_operators_iter() - .map(|op| op.id.0) - .chain(input_iter) - .for_each(|ind| uf.union(first_inp, ind)); + // We apply the composition rules + for (to_id, froms) in dag.composition.clone() { + for from_id in froms { + uf.union(to_id.0, from_id.0); + } } Blocks::from(uf) } #[derive(Clone, Debug, Default)] struct BlockConstraints { - forced: HashSet, // hard constraints, need to be resolved, given by PartitionFromOp - exit: HashSet, // soft constraints, to have less inter partition keyswitch in TLUs + forced: HashSet, // hard constraints, need to be resolved, given by PartitionFromOp + exit: HashSet, // soft constraints, to have less inter partition keyswitch in TLUs } /* For each levelled block collect BlockConstraints */ @@ -117,7 +113,8 @@ fn get_singleton_value(hashset: &HashSet) -> V { } fn only_1_partition(dag: &unparametrized::Dag) -> Partitions { - let mut instrs_partition = vec![InstructionPartition::new(0); dag.operators.len()]; + let mut instrs_partition = + vec![InstructionPartition::new(PartitionIndex::FIRST); dag.operators.len()]; for (op_i, op) in dag.operators.iter().enumerate() { match op { Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. 
} => { @@ -140,9 +137,8 @@ fn resolve_by_levelled_block( dag: &unparametrized::Dag, p_cut: &PartitionCut, default_partition: PartitionIndex, - composable: bool, ) -> Partitions { - let blocks = extract_levelled_block(dag, composable); + let blocks = extract_levelled_block(dag); let constraints_by_blocks = levelled_blocks_constraints(dag, &blocks, p_cut); let present_partitions: HashSet = constraints_by_blocks .iter() @@ -155,7 +151,6 @@ fn resolve_by_levelled_block( dag, &p_cut.delete_unused_cut(&present_partitions), default_partition, - composable, ); } if nb_partitions == 1 { @@ -251,12 +246,11 @@ pub fn partitionning_with_preferred( dag: &unparametrized::Dag, p_cut: &PartitionCut, default_partition: PartitionIndex, - composable: bool, ) -> Partitions { if p_cut.p_cut.is_empty() { only_1_partition(dag) } else { - resolve_by_levelled_block(dag, p_cut, default_partition, composable) + resolve_by_levelled_block(dag, p_cut, default_partition) } } @@ -264,8 +258,8 @@ pub fn partitionning_with_preferred( pub mod tests { // 2 Partitions labels - pub const LOW_PRECISION_PARTITION: PartitionIndex = 0; - pub const HIGH_PRECISION_PARTITION: PartitionIndex = 1; + pub const LOW_PRECISION_PARTITION: PartitionIndex = PartitionIndex(0); + pub const HIGH_PRECISION_PARTITION: PartitionIndex = PartitionIndex(1); use super::*; use crate::dag::operator::{FunctionTable, Shape, Weights}; @@ -275,27 +269,25 @@ pub mod tests { PartitionCut::from_precisions(&[2, 128]) } - fn partitionning_no_p_cut(dag: &unparametrized::Dag, composable: bool) -> Partitions { + fn partitionning_no_p_cut(dag: &unparametrized::Dag) -> Partitions { let p_cut = PartitionCut::empty(); - partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION, composable) + partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION) } - fn partitionning(dag: &unparametrized::Dag, composable: bool) -> Partitions { + fn partitionning(dag: &unparametrized::Dag) -> Partitions { partitionning_with_preferred( dag, 
&PartitionCut::for_each_precision(dag), LOW_PRECISION_PARTITION, - composable, ) } fn partitionning_with_preferred( dag: &unparametrized::Dag, p_cut: &PartitionCut, - default_partition: usize, - composable: bool, + default_partition: PartitionIndex, ) -> Partitions { - super::partitionning_with_preferred(dag, p_cut, default_partition, composable) + super::partitionning_with_preferred(dag, p_cut, default_partition) } pub fn show_partitionning(dag: &unparametrized::Dag, partitions: &[InstructionPartition]) { @@ -334,7 +326,7 @@ pub mod tests { let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(16, Shape::number()); _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); - let instrs_partition = partitionning_no_p_cut(&dag, false).instrs_partition; + let instrs_partition = partitionning_no_p_cut(&dag).instrs_partition; for instr_partition in instrs_partition { assert!(instr_partition.instruction_partition == LOW_PRECISION_PARTITION); assert!(instr_partition.no_transition()); @@ -345,7 +337,7 @@ pub mod tests { fn test_1_input_2_partitions() { let mut dag = unparametrized::Dag::new(); _ = dag.add_input(1, Shape::number()); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); assert!(partitions.nb_partitions == 1); let instrs_partition = partitions.instrs_partition; assert!(instrs_partition[0].instruction_partition == LOW_PRECISION_PARTITION); @@ -358,12 +350,13 @@ pub mod tests { let input = dag.add_input(10, Shape::number()); let lut1 = dag.add_lut(input, FunctionTable::UNKWOWN, 2); let output = dag.add_lut(lut1, FunctionTable::UNKWOWN, 10); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); assert!( partitions.instrs_partition[input.0].instruction_partition != partitions.instrs_partition[output.0].instruction_partition ); - let partitions = partitionning(&dag, true); + dag.add_composition(output, input); + let partitions = partitionning(&dag); assert!( 
partitions.instrs_partition[input.0].instruction_partition == partitions.instrs_partition[output.0].instruction_partition @@ -386,7 +379,7 @@ pub mod tests { expected_partitions.push(LOW_PRECISION_PARTITION); let lut5 = dag.add_lut(lut4, FunctionTable::UNKWOWN, 8); expected_partitions.push(HIGH_PRECISION_PARTITION); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); assert!(partitions.nb_partitions == 2); let instrs_partition = partitions.instrs_partition; let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0]; @@ -407,7 +400,7 @@ pub mod tests { let input2 = dag.add_input(1, Shape::number()); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8); let _dot = dag.add_dot([input1, lut2], Weights::from([1, 1])); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); assert!(partitions.nb_partitions == 1); } @@ -418,7 +411,7 @@ pub mod tests { let input2 = dag.add_input(1, Shape::number()); let lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, 1); let _dot = dag.add_dot([input2, lut2], Weights::from([1, 1])); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); assert!(partitions.nb_partitions == 1); } @@ -430,7 +423,7 @@ pub mod tests { let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8); let dot = dag.add_dot([lut1, lut2], Weights::from([1, 1])); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); let consider = |op_i: OperatorIndex| &partitions.instrs_partition[op_i.0]; // input1 let p = consider(input1); @@ -486,7 +479,7 @@ pub mod tests { let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let rounded2 = dag.add_expanded_round(lut1, precision); let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); let consider = |op_i| 
&partitions.instrs_partition[op_i]; // First layer is fully LOW_PRECISION_PARTITION for op_i in input1.0..lut1.0 { @@ -536,7 +529,7 @@ pub mod tests { let rounded1 = dag.add_expanded_round(input1, precision); let rounded_layer: Vec<_> = ((input1.0 + 1)..rounded1.0).collect(); let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); - let partitions = partitionning(&dag, false); + let partitions = partitionning(&dag); let consider = |op_i: usize| &partitions.instrs_partition[op_i]; // First layer is fully HIGH_PRECISION_PARTITION @@ -597,7 +590,7 @@ pub mod tests { let rounded_layer = (input1.0 + 1)..rounded1.0; let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let partitions = - partitionning_with_preferred(&dag, &default_p_cut(), HIGH_PRECISION_PARTITION, false); + partitionning_with_preferred(&dag, &default_p_cut(), HIGH_PRECISION_PARTITION); show_partitionning(&dag, &partitions.instrs_partition); let consider = |op_i: usize| &partitions.instrs_partition[op_i]; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs index 90e1b89d4..28ee90c61 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs @@ -1,6 +1,38 @@ -use std::collections::HashSet; +use std::{ + collections::HashSet, + fmt::Display, + ops::{Deref, Index, IndexMut}, +}; + +use crate::dag::operator::OperatorIndex; + +#[derive(Clone, Debug, PartialEq, Eq, Default, PartialOrd, Ord, Hash, Copy)] +pub struct PartitionIndex(pub(crate) usize); + +impl PartitionIndex { + pub const FIRST: Self = Self(0); + + pub const INVALID: Self = Self(usize::MAX); + + pub fn range(from: usize, to: usize) -> impl DoubleEndedIterator { + (from..to).map(PartitionIndex) 
+ } +} + +impl Display for PartitionIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Deref for PartitionIndex { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} -pub type PartitionIndex = usize; pub type AdditionalRepresentations = HashSet; // How one input is made compatible with the instruction partition @@ -47,3 +79,30 @@ pub struct Partitions { pub nb_partitions: usize, pub instrs_partition: Vec, } + +impl Index for Partitions { + type Output = InstructionPartition; + + fn index(&self, index: OperatorIndex) -> &Self::Output { + &self.instrs_partition[index.0] + } +} + +impl IndexMut for Partitions { + fn index_mut(&mut self, index: OperatorIndex) -> &mut Self::Output { + &mut self.instrs_partition[index.0] + } +} + +#[allow(unused)] +pub struct PartitionsCircuit<'part> { + pub(crate) partitions: &'part [InstructionPartition], + pub(crate) idx: Vec, +} + +impl<'part> PartitionsCircuit<'part> { + #[allow(unused)] + pub fn get_node_iter(&'part self) -> impl Iterator { + self.idx.iter().map(|i| self.partitions.get(*i).unwrap()) + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs index f3fdc8717..4bf70f5b8 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs @@ -2,6 +2,8 @@ use std::fmt; use crate::optimization::dag::multi_parameters::operations_value::OperationsValue; +use super::partitions::PartitionIndex; + /** * A variance that is represented as a linear combination of base variances. * Only the linear coefficient are known. 
@@ -20,14 +22,14 @@ use crate::optimization::dag::multi_parameters::operations_value::OperationsValu */ #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct SymbolicVariance { - pub partition: usize, + pub partition: PartitionIndex, pub coeffs: OperationsValue, } impl SymbolicVariance { // To be used as a initial accumulator pub const ZERO: Self = Self { - partition: 0, + partition: PartitionIndex::FIRST, coeffs: OperationsValue::ZERO, }; @@ -37,12 +39,12 @@ impl SymbolicVariance { pub fn nan(nb_partitions: usize) -> Self { Self { - partition: usize::MAX, + partition: PartitionIndex::INVALID, coeffs: OperationsValue::nan(nb_partitions), } } - pub fn input(nb_partitions: usize, partition: usize) -> Self { + pub fn input(nb_partitions: usize, partition: PartitionIndex) -> Self { let mut r = Self { partition, coeffs: OperationsValue::zero(nb_partitions), @@ -52,11 +54,11 @@ impl SymbolicVariance { r } - pub fn coeff_input(&self, partition: usize) -> f64 { + pub fn coeff_input(&self, partition: PartitionIndex) -> f64 { self.coeffs[self.coeffs.index.input(partition)] } - pub fn after_pbs(nb_partitions: usize, partition: usize) -> Self { + pub fn after_pbs(nb_partitions: usize, partition: PartitionIndex) -> Self { let mut r = Self { partition, coeffs: OperationsValue::zero(nb_partitions), @@ -65,15 +67,15 @@ impl SymbolicVariance { r } - pub fn coeff_pbs(&self, partition: usize) -> f64 { + pub fn coeff_pbs(&self, partition: PartitionIndex) -> f64 { self.coeffs[self.coeffs.index.pbs(partition)] } - pub fn coeff_modulus_switching(&self, partition: usize) -> f64 { + pub fn coeff_modulus_switching(&self, partition: PartitionIndex) -> f64 { self.coeffs[self.coeffs.index.modulus_switching(partition)] } - pub fn after_modulus_switching(&self, partition: usize) -> Self { + pub fn after_modulus_switching(&self, partition: PartitionIndex) -> Self { let mut new = self.clone(); let index = self.coeffs.index.modulus_switching(partition); assert!(new.coeffs[index] == 0.0); @@ 
-81,7 +83,11 @@ impl SymbolicVariance { new } - pub fn coeff_keyswitch_to_small(&self, src_partition: usize, dst_partition: usize) -> f64 { + pub fn coeff_keyswitch_to_small( + &self, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, + ) -> f64 { self.coeffs[self .coeffs .index @@ -90,8 +96,8 @@ impl SymbolicVariance { pub fn after_partition_keyswitch_to_small( &self, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, ) -> Self { let index = self .coeffs @@ -102,8 +108,8 @@ impl SymbolicVariance { pub fn coeff_partition_keyswitch_to_big( &self, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, ) -> f64 { self.coeffs[self .coeffs @@ -113,8 +119,8 @@ impl SymbolicVariance { pub fn after_partition_keyswitch_to_big( &self, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, ) -> Self { let index = self .coeffs @@ -125,12 +131,12 @@ impl SymbolicVariance { pub fn after_partition_keyswitch( &self, - src_partition: usize, - dst_partition: usize, + src_partition: PartitionIndex, + dst_partition: PartitionIndex, index: usize, ) -> Self { - assert!(src_partition < self.nb_partitions()); - assert!(dst_partition < self.nb_partitions()); + assert!(src_partition.0 < self.nb_partitions()); + assert!(dst_partition.0 < self.nb_partitions()); assert!(src_partition == self.partition); let mut new = self.clone(); new.partition = dst_partition; @@ -144,7 +150,7 @@ impl SymbolicVariance { // detect the previous base manp level // this is the maximum value of fresh base noise and pbs base noise let mut current_max: f64 = 0.0; - for partition in 0..self.nb_partitions() { + for partition in PartitionIndex::range(0, self.nb_partitions()) { let fresh_coeff = self.coeff_input(partition); let pbs_noise_coeff = self.coeff_pbs(partition); current_max = 
current_max.max(fresh_coeff).max(pbs_noise_coeff); @@ -195,7 +201,7 @@ impl fmt::Display for SymbolicVariance { return write!(f, "NAN x σ²"); } let mut add_plus = ""; - for src_partition in 0..self.nb_partitions() { + for src_partition in PartitionIndex::range(0, self.nb_partitions()) { let coeff = self.coeff_input(src_partition); if coeff != 0.0 { write!(f, "{add_plus}{coeff}σ²In[{src_partition}]")?; @@ -206,7 +212,7 @@ impl fmt::Display for SymbolicVariance { write!(f, "{add_plus}{coeff}σ²Br[{src_partition}]")?; add_plus = " + "; } - for dst_partition in 0..self.nb_partitions() { + for dst_partition in PartitionIndex::range(0, self.nb_partitions()) { let coeff = self.coeff_partition_keyswitch_to_big(src_partition, dst_partition); if coeff != 0.0 { write!(f, "{add_plus}{coeff}σ²FK[{src_partition}→{dst_partition}]")?; @@ -214,8 +220,8 @@ impl fmt::Display for SymbolicVariance { } } } - for src_partition in 0..self.nb_partitions() { - for dst_partition in 0..self.nb_partitions() { + for src_partition in PartitionIndex::range(0, self.nb_partitions()) { + for dst_partition in PartitionIndex::range(0, self.nb_partitions()) { let coeff = self.coeff_keyswitch_to_small(src_partition, dst_partition); if coeff != 0.0 { if src_partition == dst_partition { @@ -227,7 +233,7 @@ impl fmt::Display for SymbolicVariance { } } } - for partition in 0..self.nb_partitions() { + for partition in PartitionIndex::range(0, self.nb_partitions()) { let coeff = self.coeff_modulus_switching(partition); if coeff != 0.0 { write!(f, "{add_plus}{coeff}σ²M[{partition}]")?; @@ -238,6 +244,21 @@ impl fmt::Display for SymbolicVariance { } } +impl std::ops::Add for SymbolicVariance { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self::Output { + if self.coeffs.is_empty() { + self = rhs; + } else { + for i in 0..self.coeffs.len() { + self.coeffs[i] += rhs.coeffs[i]; + } + }; + self + } +} + impl std::ops::AddAssign for SymbolicVariance { fn add_assign(&mut self, rhs: Self) { if 
self.coeffs.is_empty() { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs index adbe3bec5..2adb375d1 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs @@ -49,7 +49,7 @@ impl VarianceConstraint { let self_renorm = other.safe_variance_bound / self.safe_variance_bound; let rel_diff = |f: &dyn Fn(&SymbolicVariance) -> f64| self_renorm * f(self_var) - f(other_var); - for partition in 0..self.variance.nb_partitions() { + for partition in PartitionIndex::range(0, self.variance.nb_partitions()) { let diffs = [ rel_diff(&|var| var.coeff_pbs(partition)), rel_diff(&|var| var.coeff_pbs(partition) + var.coeff_input(partition)), @@ -61,8 +61,8 @@ impl VarianceConstraint { } } } - for src_partition in 0..self.variance.nb_partitions() { - for dst_partition in 0..self.variance.nb_partitions() { + for src_partition in PartitionIndex::range(0, self.variance.nb_partitions()) { + for dst_partition in PartitionIndex::range(0, self.variance.nb_partitions()) { let diffs = [ rel_diff(&|var| var.coeff_keyswitch_to_small(src_partition, dst_partition)), rel_diff(&|var| { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 48b5d44a2..a42f48911 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -1,6 +1,6 @@ use super::symbolic_variance::{SymbolicVariance, VarianceOrigin}; use crate::dag::operator::{ - dot_kind, LevelledComplexity, Operator, 
OperatorIndex, Precision, Shape, + DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, }; use crate::dag::rewrite::round::expand_round; use crate::dag::unparametrized::Dag; @@ -8,7 +8,6 @@ use crate::noise_estimator::error; use crate::noise_estimator::p_error::{combine_errors, repeat_p_error}; use crate::optimization::config::NoiseBoundConfig; use crate::utils::square; -use dot_kind::DotKind; use std::collections::{HashMap, HashSet}; pub fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property { @@ -140,7 +139,7 @@ pub struct VariancesAndBound { fn out_variance( op: &Operator, - out_shapes: &[Shape], + _out_shapes: &[Shape], out_variances: &[SymbolicVariance], ) -> SymbolicVariance { // Maintain a linear combination of input_variance and lut_out_variance @@ -158,28 +157,37 @@ fn out_variance( origin * variance_factor } Operator::Dot { - inputs, weights, .. - } => { - let input_shape = first(inputs, out_shapes); - let kind = dot_kind(inputs.len() as u64, input_shape, weights); - match kind { - DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. } => { - let first_input = inputs[0]; - let mut out_variance = SymbolicVariance::ZERO; - for (j, &weight) in weights.values.iter().enumerate() { - let k = if inputs.len() > 1 { - inputs[j].0 - } else { - first_input.0 - }; - out_variance += out_variances[k] * square(weight as f64); - } - out_variance - } - DotKind::CompatibleTensor { .. } => todo!("TODO"), - DotKind::Unsupported { .. } => panic!("Unsupported"), - } + kind: DotKind::CompatibleTensor { .. }, + .. + } => todo!("TODO"), + Operator::Dot { + kind: DotKind::Unsupported { .. }, + .. + } => panic!("Unsupported"), + Operator::Dot { + inputs, + weights, + kind: DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. 
}, + } if inputs.len() == 1 => { + let var = out_variances[inputs.iter().next().unwrap().0]; + weights + .values + .iter() + .fold(SymbolicVariance::ZERO, |acc, weight| { + acc + var * square(*weight as f64) + }) } + Operator::Dot { + inputs, + weights, + kind: DotKind::Simple | DotKind::Tensor | DotKind::Broadcast { .. }, + } => weights + .values + .iter() + .zip(inputs.iter().map(|n| out_variances[n.0])) + .fold(SymbolicVariance::ZERO, |acc, (weight, var)| { + acc + var * square(*weight as f64) + }), Operator::UnsafeCast { input, .. } => out_variances[input.0], Operator::Round { .. } => { unreachable!("Round should have been either expanded or integrated to a lut") @@ -230,20 +238,15 @@ fn in_luts_variance( fn op_levelled_complexity(op: &Operator, out_shapes: &[Shape]) -> LevelledComplexity { match op { Operator::Dot { - inputs, weights, .. - } => { - let input_shape = first(inputs, out_shapes); - let kind = dot_kind(inputs.len() as u64, input_shape, weights); - match kind { - DotKind::Simple - | DotKind::Tensor - | DotKind::Broadcast { .. } - | DotKind::CompatibleTensor => { - LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size() - } - DotKind::Unsupported { .. } => panic!("Unsupported"), - } + kind: DotKind::Unsupported, + .. + } => panic!("Unsupported"), + Operator::Dot { inputs, .. } => { + LevelledComplexity::ADDITION + * (inputs.len() as u64) + * out_shapes[inputs[0].0].flat_size() } + Operator::LevelledOp { complexity, .. } => *complexity, Operator::Input { .. } | Operator::Lut { .. } | Operator::UnsafeCast { .. 
} => { LevelledComplexity::ZERO diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index aff178f6e..4d4ae204a 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -480,7 +480,6 @@ pub(crate) mod tests { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), - composable: false, }; let search_space = SearchSpace::default_cpu(); @@ -526,7 +525,6 @@ pub(crate) mod tests { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), - composable: false, }; _ = optimize_v0( @@ -625,7 +623,6 @@ pub(crate) mod tests { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), - composable: false, }; let state = optimize(&dag); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs index 810d7c0b4..4684790bc 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs @@ -36,6 +36,13 @@ impl Viz for crate::dag::unparametrized::Dag { let mut graph = vec![]; self.get_circuits_iter() .for_each(|circuit| graph.push(circuit.viz_node())); + self.composition + .clone() + .into_iter() + .flat_map(|(to, froms)| froms.into_iter().map(move |f| (to, f))) + .for_each(|(to, from)| { + graph.push(format!("{from} -> {to} [color=red, style=dashed]")); + }); graph.join("\n") } } @@ -108,10 +115,66 @@ impl<'dag> Viz for crate::dag::unparametrized::DagOperator<'dag> { } } +impl Viz for crate::optimization::dag::multi_parameters::analyze::PartitionedDag { + fn viz_node(&self) -> String { + let mut output = self.dag.viz_node(); + 
self.partitions + .instrs_partition + .iter() + .enumerate() + .for_each(|(i, part)| { + let partition = part.instruction_partition; + let circuit = &self.dag.circuit_tags[i]; + // let color = partition.0 + 1; + output.push_str(&format!("subgraph cluster_circuit_{circuit} {{\n")); + output.push_str(&format!("partition_{i} [label =\"{partition}\"];\n")); + output.push_str(&format!( + "partition_{i} -> {i} [arrowhead=none, color=gray80, weight=99];\n" + )); + output.push_str("}\n"); + }); + output + } +} + +impl Viz for crate::optimization::dag::multi_parameters::analyze::VariancedDag { + fn viz_node(&self) -> String { + let mut output = self.dag.viz_node(); + self.partitions + .instrs_partition + .iter() + .zip(self.variances.vars.iter()) + .enumerate() + .for_each(|(i, (part, var))| { + let partition = part.instruction_partition; + let circuit = &self.dag.circuit_tags[i]; + let variances = var + .vars + .iter() + .enumerate() + .map(|(i, var)| format!("{{{i}|{var}}}")) + .collect::>() + .join("|"); + let label = format!( + "<{{Partition | {partition} | Variances | {variances} }}>" + ); + output.push_str(&format!("subgraph cluster_circuit_{circuit} {{\n")); + output.push_str(&format!( + "info_{i} [label ={label} color=gray80 fillcolor=gray90];\n" + )); + output.push_str(&format!( + "{i} -> info_{i} [arrowhead=none, color=gray90, weight=99];\n" + )); + output.push_str("}\n"); + }); + output + } +} + macro_rules! _viz { ($path: expr, $object:expr) => {{ let mut path = std::env::temp_dir(); - path.push($path.as_str()); + path.push(AsRef::::as_ref($path)); let _ = std::process::Command::new("sh") .arg("-c") .arg(format!( @@ -139,7 +202,7 @@ macro_rules! viz { }; ($object:expr) => { let name = format!("concrete_optimizer_dbg_{}.svg", rand::random::()); - $crate::utils::viz::viz!(name, $object); + $crate::utils::viz::viz!(&name, $object); }; } @@ -158,7 +221,7 @@ macro_rules! 
vizp { }}; ($object:expr) => { let name = format!("concrete_optimizer_dbg_{}.svg", rand::random::()); - $crate::utils::viz::vizp!(name, $object); + $crate::utils::viz::vizp!(&name, $object); }; } diff --git a/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs b/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs index 7c5026144..5ce7a8e11 100644 --- a/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs +++ b/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs @@ -20,7 +20,6 @@ fn v0_pbs_optimization(c: &mut Criterion) { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, - composable: false, }; c.bench_function("v0 PBS table generation", |b| { @@ -47,7 +46,6 @@ fn v0_pbs_optimization_simulate_graph(c: &mut Criterion) { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, - composable: false, }; c.bench_function("v0 PBS simulate dag table generation", |b| { @@ -74,7 +72,6 @@ fn v0_wop_pbs_optimization(c: &mut Criterion) { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, - composable: false, }; c.bench_function("v0 WoP-PBS table generation", |b| { diff --git a/compilers/concrete-optimizer/v0-parameters/src/lib.rs b/compilers/concrete-optimizer/v0-parameters/src/lib.rs index 4a2b39090..e3321efd5 100644 --- a/compilers/concrete-optimizer/v0-parameters/src/lib.rs +++ b/compilers/concrete-optimizer/v0-parameters/src/lib.rs @@ -92,9 +92,6 @@ pub struct Args { #[clap(long, default_value_t = 53)] pub fft_precision: u32, - - #[clap(long)] - pub composable: bool, } pub fn all_results(args: &Args) -> Vec>> { @@ -103,7 +100,6 @@ pub fn all_results(args: &Args) -> Vec>> { let maximum_acceptable_error_probability = args.p_error; let security_level = args.security_level; let cache_on_disk = args.cache_on_disk; - let composable = args.composable; let search_space = SearchSpace { glwe_log_polynomial_sizes: (args.min_log_poly_size..=args.max_log_poly_size).collect(), @@ -126,7 +122,6 @@ 
pub fn all_results(args: &Args) -> Vec>> { ciphertext_modulus_log: args.ciphertext_modulus_log, fft_precision: args.fft_precision, complexity_model: &CpuComplexity::default(), - composable, }; let cache = decomposition::cache( @@ -300,7 +295,6 @@ mod tests { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, - composable: false, }; let mut actual_output = Vec::::new(); @@ -345,7 +339,6 @@ mod tests { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, - composable: false, }; let mut actual_output = Vec::::new();