From 60da7133126fc7bc61e98799274a2e194ae5f228 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Thu, 7 Dec 2023 09:41:53 +0100 Subject: [PATCH] feat(optimizer): adds support for function composition --- .../concretelang/Support/V0Parameters.h | 3 + .../concretelang/TestLib/TestCircuit.h | 22 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 4 + .../concrete/compiler/compilation_options.py | 13 + .../compiler/lib/Support/V0Parameters.cpp | 2 +- .../concrete-compiler/compiler/src/main.cpp | 7 + .../end_to_end_tests/end_to_end_jit_test.cc | 30 + .../end_to_end_tests/end_to_end_jit_test.h | 10 +- compilers/concrete-optimizer/Cargo.toml | 1 + .../charts/src/bin/norm2_complexity.rs | 1 + .../charts/src/bin/precision_complexity.rs | 1 + .../src/concrete-optimizer.rs | 5 + .../src/cpp/concrete-optimizer.cpp | 1 + .../src/cpp/concrete-optimizer.hpp | 1 + .../concrete-optimizer-cpp/tests/src/main.cpp | 1 + .../src/dag/unparametrized.rs | 19 + .../src/optimization/config.rs | 1 + .../dag/multi_parameters/analyze.rs | 149 +++- .../{optimize.rs => optimize/mod.rs} | 64 +- .../dag/multi_parameters/optimize/tests.rs | 799 ++++++++++++++++++ .../dag/multi_parameters/partitionning.rs | 65 +- .../multi_parameters/tests/test_optimize.rs | 758 ----------------- .../src/optimization/dag/solo_key/optimize.rs | 3 + .../src/optimization/mod.rs | 17 + .../v0-parameters/benches/benchmark.rs | 3 + .../v0-parameters/src/lib.rs | 7 + docs/howto/configure.md | 4 +- .../concrete/fhe/compilation/configuration.py | 11 + .../concrete/fhe/compilation/server.py | 1 + frontends/concrete-python/tests/conftest.py | 69 ++ .../tests/execution/test_composition.py | 39 + 31 files changed, 1295 insertions(+), 816 deletions(-) rename compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/{optimize.rs => optimize/mod.rs} (96%) create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs delete 
mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs create mode 100644 frontends/concrete-python/tests/execution/test_composition.py diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h index be7ae0dd4..7b8c5172d 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h @@ -80,6 +80,7 @@ constexpr concrete_optimizer::Encoding DEFAULT_ENCODING = constexpr bool DEFAULT_CACHE_ON_DISK = true; constexpr uint32_t DEFAULT_CIPHERTEXT_MODULUS_LOG = 64; constexpr uint32_t DEFAULT_FFT_PRECISION = 53; +constexpr bool DEFAULT_COMPOSABLE = false; /// The strategy of the crypto optimization enum Strategy { @@ -111,6 +112,7 @@ struct Config { bool cache_on_disk; uint32_t ciphertext_modulus_log; uint32_t fft_precision; + bool composable; }; constexpr Config DEFAULT_CONFIG = { @@ -126,6 +128,7 @@ constexpr Config DEFAULT_CONFIG = { DEFAULT_CACHE_ON_DISK, DEFAULT_CIPHERTEXT_MODULUS_LOG, DEFAULT_FFT_PRECISION, + DEFAULT_COMPOSABLE, }; using Dag = rust::Box; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestCircuit.h b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestCircuit.h index d4b521450..9f56bf8a7 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestCircuit.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestCircuit.h @@ -113,6 +113,28 @@ public: return processedOutputs; } + Result> compose_n_times(std::vector inputs, + size_t n) { + // preprocess arguments + auto preparedArgs = std::vector(); + OUTCOME_TRY(auto clientCircuit, getClientCircuit()); + for (size_t i = 0; i < inputs.size(); i++) { + OUTCOME_TRY(auto preparedInput, 
clientCircuit.prepareInput(inputs[i], i)); + preparedArgs.push_back(preparedInput); + } + // Call server multiple times in a row + for (size_t i = 0; i < n; i++) { + OUTCOME_TRY(preparedArgs, callServer(preparedArgs)); + } + // postprocess arguments + std::vector processedOutputs(preparedArgs.size()); + for (size_t i = 0; i < processedOutputs.size(); i++) { + OUTCOME_TRY(processedOutputs[i], + clientCircuit.processOutput(preparedArgs[i], i)); + } + return processedOutputs; + } + Result> callServer(std::vector inputs) { std::vector returns; diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 3c7e99d8c..df6ac89a6 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -111,6 +111,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](CompilationOptions &options, double global_p_error) { options.optimizerConfig.global_p_error = global_p_error; }) + .def("set_composable", + [](CompilationOptions &options, bool composable) { + options.optimizerConfig.composable = composable; + }) .def("set_security_level", [](CompilationOptions &options, int security_level) { options.optimizerConfig.security = security_level; diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py index d09edc29a..4a1ad1839 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py @@ -61,6 +61,19 @@ class CompilationOptions(WrapperCpp): # pylint: enable=arguments-differ + def set_composable(self, composable: bool): + """Set option for 
composition. + + Args: + composable (bool): whether to turn it on or off + + Raises: + TypeError: if the value to set is not boolean + """ + if not isinstance(composable, bool): + raise TypeError("can't set the option to a non-boolean value") + self.cpp().set_composable(composable) + def set_auto_parallelize(self, auto_parallelize: bool): """Set option for auto parallelization. diff --git a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp index 68a8a3085..932007365 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp @@ -35,7 +35,7 @@ concrete_optimizer::Options options_from_config(optimizer::Config config) { /* .cache_on_disk = */ config.cache_on_disk, /* .ciphertext_modulus_log = */ config.ciphertext_modulus_log, /* .fft_precision = */ config.fft_precision, - }; + /* .composable = */ config.composable}; return options; } diff --git a/compilers/concrete-compiler/compiler/src/main.cpp b/compilers/concrete-compiler/compiler/src/main.cpp index fc3a740b3..132b96693 100644 --- a/compilers/concrete-compiler/compiler/src/main.cpp +++ b/compilers/concrete-compiler/compiler/src/main.cpp @@ -314,6 +314,12 @@ llvm::cl::opt optimizerNoCacheOnDisk( "cache issues."), llvm::cl::init(false)); +llvm::cl::opt optimizerAllowComposition( + "optimizer-allow-composition", + llvm::cl::desc("Optimizer is parameterized to allow calling the circuit on " + "its own output without decryptions."), + llvm::cl::init(false)); + llvm::cl::list fhelinalgTileSizes( "fhelinalg-tile-sizes", llvm::cl::desc( @@ -480,6 +486,7 @@ cmdlineCompilationOptions() { options.optimizerConfig.key_sharing = cmdline::optimizerKeySharing; options.optimizerConfig.encoding = cmdline::optimizerEncoding; options.optimizerConfig.cache_on_disk = !cmdline::optimizerNoCacheOnDisk; + options.optimizerConfig.composable = 
cmdline::optimizerAllowComposition; if (!std::isnan(options.optimizerConfig.global_p_error) && options.optimizerConfig.strategy == optimizer::Strategy::V0) { diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc index 64c2cec2c..5a58a443b 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc @@ -363,3 +363,33 @@ func.func @main(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2xi8>, %acc: ASSERT_EQ(lambda({arg0, arg1, acc}), 76_u64); } + +TEST(CompileAndRunComposed, compose_add_eint) { + checkedJit(testCircuit, R"XXX( +func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { + %cst_1 = arith.constant 1 : i4 + %cst_2 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64> + %1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> + %2 = "FHE.apply_lookup_table"(%1, %cst_2): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>) + return %2: !FHE.eint<3> +} +)XXX", + "main", false, DEFAULT_dataflowParallelize, + DEFAULT_loopParallelize, DEFAULT_batchTFHEOps, + DEFAULT_global_p_error, DEFAULT_chunkedIntegers, DEFAULT_chunkSize, + DEFAULT_chunkWidth, true); + auto lambda = [&](std::vector args, size_t n) { + return testCircuit.compose_n_times(args, n) + .value()[0] + .template getTensor() + .value()[0]; + }; + ASSERT_EQ(lambda({Tensor(0)}, 1), (uint64_t)1); + ASSERT_EQ(lambda({Tensor(0)}, 2), (uint64_t)2); + ASSERT_EQ(lambda({Tensor(0)}, 3), (uint64_t)3); + ASSERT_EQ(lambda({Tensor(0)}, 4), (uint64_t)4); + ASSERT_EQ(lambda({Tensor(0)}, 5), (uint64_t)5); + ASSERT_EQ(lambda({Tensor(0)}, 6), (uint64_t)6); + ASSERT_EQ(lambda({Tensor(0)}, 7), (uint64_t)7); + ASSERT_EQ(lambda({Tensor(0)}, 8), (uint64_t)0); +} diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h 
b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h index 31ee38b1c..cfbc0bccd 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h @@ -4,6 +4,7 @@ #include "../tests_tools/keySetCache.h" #include "concretelang/Common/Error.h" #include "concretelang/Support/CompilerEngine.h" +#include "concretelang/Support/V0Parameters.h" #include "concretelang/TestLib/TestCircuit.h" #include "cstdlib" #include "end_to_end_test.h" @@ -24,6 +25,7 @@ double DEFAULT_global_p_error = TEST_ERROR_RATE; bool DEFAULT_chunkedIntegers = false; unsigned int DEFAULT_chunkSize = 4; unsigned int DEFAULT_chunkWidth = 2; +bool DEFAULT_composable = false; // Jit-compiles the function specified by `func` from `src` and // returns the corresponding lambda. Any compilation errors are caught @@ -37,7 +39,8 @@ inline Result internalCheckedJit( double global_p_error = DEFAULT_global_p_error, bool chunkedIntegers = DEFAULT_chunkedIntegers, unsigned int chunkSize = DEFAULT_chunkSize, - unsigned int chunkWidth = DEFAULT_chunkWidth) { + unsigned int chunkWidth = DEFAULT_chunkWidth, + bool composable = DEFAULT_composable) { auto options = mlir::concretelang::CompilationOptions(std::string(func.data())); @@ -60,6 +63,11 @@ inline Result internalCheckedJit( #endif #endif options.batchTFHEOps = batchTFHEOps; + if (composable) { + options.optimizerConfig.composable = composable; + options.optimizerConfig.strategy = + mlir::concretelang::optimizer::Strategy::DAG_MULTI; + } std::vector sources = {src.str()}; TestCircuit testCircuit(options); diff --git a/compilers/concrete-optimizer/Cargo.toml b/compilers/concrete-optimizer/Cargo.toml index 12c6310c8..3275854a6 100644 --- a/compilers/concrete-optimizer/Cargo.toml +++ b/compilers/concrete-optimizer/Cargo.toml @@ -8,6 +8,7 @@ members = [ "brute-force-optimizer", ] +resolver = "2" [profile.test] diff --git 
a/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs b/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs index 83cc4144b..3ac31bbbc 100644 --- a/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs +++ b/compilers/concrete-optimizer/charts/src/bin/norm2_complexity.rs @@ -47,6 +47,7 @@ fn main() -> Result<(), Box> { ciphertext_modulus_log, fft_precision, complexity_model: &CpuComplexity::default(), + composable: false, }; let cache = decomposition::cache( diff --git a/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs b/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs index 068cf6efb..240d44b86 100644 --- a/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs +++ b/compilers/concrete-optimizer/charts/src/bin/precision_complexity.rs @@ -47,6 +47,7 @@ fn main() -> Result<(), Box> { ciphertext_modulus_log, fft_precision, complexity_model: &CpuComplexity::default(), + composable: false, }; let cache = decomposition::cache( diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index 25a11716f..325b45825 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -55,6 +55,7 @@ fn optimize_bootstrap(precision: u64, noise_factor: f64, options: ffi::Options) ciphertext_modulus_log: options.ciphertext_modulus_log, fft_precision: options.fft_precision, complexity_model: &CpuComplexity::default(), + composable: options.composable, }; let sum_size = 1; @@ -510,6 +511,7 @@ impl OperationDag { ciphertext_modulus_log: options.ciphertext_modulus_log, fft_precision: options.fft_precision, complexity_model: &CpuComplexity::default(), + composable: options.composable, }; let search_space = SearchSpace::default(processing_unit); @@ -534,6 +536,7 @@ impl OperationDag { 
ciphertext_modulus_log: options.ciphertext_modulus_log, fft_precision: options.fft_precision, complexity_model: &CpuComplexity::default(), + composable: options.composable, }; let search_space = SearchSpace::default(processing_unit); @@ -563,6 +566,7 @@ impl OperationDag { ciphertext_modulus_log: options.ciphertext_modulus_log, fft_precision: options.fft_precision, complexity_model: &CpuComplexity::default(), + composable: options.composable, }; let search_space = SearchSpace::default(processing_unit); @@ -770,6 +774,7 @@ mod ffi { pub cache_on_disk: bool, pub ciphertext_modulus_log: u32, pub fft_precision: u32, + pub composable: bool, } #[namespace = "concrete_optimizer::dag"] diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index de79044dd..62af7271f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -1083,6 +1083,7 @@ struct Options final { bool cache_on_disk; ::std::uint32_t ciphertext_modulus_log; ::std::uint32_t fft_precision; + bool composable; using IsRelocatable = ::std::true_type; }; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index c04e198b9..7a5371fef 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -1064,6 +1064,7 @@ struct Options final { bool cache_on_disk; ::std::uint32_t ciphertext_modulus_log; ::std::uint32_t fft_precision; + bool composable; using IsRelocatable = ::std::true_type; }; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp 
b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp index 417e1bec5..6a47f442e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp @@ -29,6 +29,7 @@ concrete_optimizer::Options default_options() { .cache_on_disk = true, .ciphertext_modulus_log = CIPHERTEXT_MODULUS_LOG, .fft_precision = 53, + .composable = false }; } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index 08b8e9629..7535a3075 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -5,6 +5,7 @@ use crate::dag::operator::{ dot_kind, DotKind, FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights, }; +use crate::optimization::dag::solo_key::analyze::extra_final_values_to_check; pub(crate) type UnparameterizedOperator = Operator; @@ -243,6 +244,24 @@ impl OperationDag { self.add_lut(rounded, table, out_precision) } + pub(crate) fn get_input_index_iter(&self) -> impl Iterator + '_ { + self.operators + .iter() + .enumerate() + .filter_map(|(index, op)| match op { + Operator::Input { .. } => Some(index), + _ => None, + }) + } + + pub(crate) fn get_output_index(&self) -> Vec { + return extra_final_values_to_check(self) + .iter() + .enumerate() + .filter_map(|(index, is_output)| is_output.then_some(index)) + .collect(); + } + fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape { match op { Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. 
} => { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs index d23c255b7..62775295c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/config.rs @@ -18,6 +18,7 @@ pub struct Config<'a> { pub ciphertext_modulus_log: u32, pub fft_precision: u32, pub complexity_model: &'a dyn ComplexityModel, + pub composable: bool, } #[derive(Clone, Debug)] diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index 8dd329a77..3b45e73e6 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -15,6 +15,8 @@ use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVaria use crate::optimization::dag::solo_key::analyze::{ extra_final_values_to_check, first, safe_noise_bound, }; +use crate::optimization::Err::NotComposable; +use crate::optimization::Result; use super::complexity::OperationsCount; use super::keys_spec; @@ -28,6 +30,7 @@ use DotKind as DK; type Op = Operator; +#[derive(Debug)] pub struct AnalyzedDag { pub operators: Vec, // Collect all operators ouput variances @@ -51,7 +54,8 @@ pub fn analyze( noise_config: &NoiseBoundConfig, p_cut: &Option, default_partition: PartitionIndex, -) -> AnalyzedDag { + composable: bool, +) -> Result { let (dag, instruction_rewrite_index) = expand_round_and_index_map(dag); let levelled_complexity = LevelledComplexity::ZERO; // The precision cut is chosen to work well with rounded pbs @@ -61,10 +65,35 @@ pub fn analyze( Some(p_cut) => p_cut.clone(), None => maximal_p_cut(&dag), }; - let partitions = 
partitionning_with_preferred(&dag, &p_cut, default_partition); + let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition, composable); let instrs_partition = partitions.instrs_partition; let nb_partitions = partitions.nb_partitions; - let out_variances = out_variances(&dag, nb_partitions, &instrs_partition); + let mut out_variances = self::out_variances(&dag, nb_partitions, &instrs_partition, &None); + if composable { + // Verify that there is no input symbol in the symbolic variances of the outputs. + if !no_input_var_in_out_var(&dag, &out_variances, nb_partitions) { + return Err(NotComposable); + } + // Get the largest output out_variance + let largest_output_variances = dag + .get_output_index() + .into_iter() + .map(|index| out_variances[index].clone()) + .reduce(|lhs, rhs| { + lhs.into_iter() + .zip(rhs) + .map(|(lhsi, rhsi)| lhsi.max(&rhsi)) + .collect() + }) + .expect("Failed to get the largest output variance."); + // Re-compute the out variances with input variances overriden by input variances + out_variances = self::out_variances( + &dag, + nb_partitions, + &instrs_partition, + &Some(largest_output_variances), + ); + } let variance_constraints = collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances); let undominated_variance_constraints = @@ -72,7 +101,7 @@ pub fn analyze( let operations_count_per_instrs = collect_operations_count(&dag, nb_partitions, &instrs_partition); let operations_count = sum_operations_count(&operations_count_per_instrs); - AnalyzedDag { + Ok(AnalyzedDag { operators: dag.operators, instruction_rewrite_index, nb_partitions, @@ -84,7 +113,28 @@ pub fn analyze( operations_count_per_instrs, operations_count, p_cut, - } + }) +} + +fn no_input_var_in_out_var( + dag: &unparametrized::OperationDag, + symbolic_variances: &[Vec], + nb_partitions: usize, +) -> bool { + // let a = dag.get_output_index() + // .iter() + // .flat_map(|index| symbolic_variances[*index].iter()).collect::>(); + 
// println!("symb_variances: {:?}", a); + + dag.get_output_index() + .iter() + .flat_map(|index| symbolic_variances[*index].iter()) + .all(|sym_var| { + (0..nb_partitions).all(|partition| { + let coeff = sym_var.coeff_input(partition); + coeff == 0.0f64 || coeff.is_nan() + }) + }) } pub fn original_instrs_partition( @@ -166,7 +216,12 @@ fn out_variance( out_variances: &[Vec], nb_partitions: usize, instr_partition: &InstructionPartition, + input_override: Option>, ) -> Vec { + // If an override is given for input and we have an input node, we override. + if let (Some(overr), Op::Input { .. }) = (input_override, op) { + return overr; + } // one variance per partition, in case the result is converted let partition = instr_partition.instruction_partition; let out_variance_of = |input: &OperatorIndex| { @@ -234,6 +289,7 @@ fn out_variances( dag: &unparametrized::OperationDag, nb_partitions: usize, instrs_partition: &[InstructionPartition], + input_override: &Option>, ) -> Vec> { let nb_ops = dag.operators.len(); let mut out_variances = Vec::with_capacity(nb_ops); @@ -244,6 +300,7 @@ fn out_variances( &out_variances, nb_partitions, instr_partition, + input_override.clone(), ); out_variances.push(vf); } @@ -405,7 +462,7 @@ pub mod tests { default_partition: PartitionIndex, ) -> AnalyzedDag { let p_cut = PrecisionCut { p_cut: vec![2] }; - super::analyze(dag, &CONFIG, &Some(p_cut), default_partition) + super::analyze(dag, &CONFIG, &Some(p_cut), default_partition, false).unwrap() } #[allow(clippy::float_cmp)] @@ -472,6 +529,77 @@ pub mod tests { } } + #[test] + fn test_composition_with_input_fails() { + let mut dag = unparametrized::OperationDag::new(); + let _ = dag.add_input(1, Shape::number()); + let p_cut = PrecisionCut { p_cut: vec![2] }; + let res = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true); + assert!(res.is_err()); + assert!(res.unwrap_err() == NotComposable); + } + + #[test] + fn test_composition_1_partition() { + let mut dag = 
unparametrized::OperationDag::new(); + let input1 = dag.add_input(1, Shape::number()); + let _ = dag.add_lut(input1, FunctionTable::UNKWOWN, 2); + let p_cut = PrecisionCut { p_cut: vec![2] }; + let dag = + super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION, true).unwrap(); + assert!(dag.nb_partitions == 1); + let actual_constraint_strings = dag + .variance_constraints + .iter() + .map(ToString::to_string) + .collect::>(); + let expected_constraint_strings = vec![ + "1σ²Br[0] + 1σ²K[0] + 1σ²M[0] < (2²)**-5 (1bits partition:0 count:1, dom=10)", + "1σ²Br[0] < (2²)**-6 (2bits partition:0 count:1, dom=12)", + ]; + assert!(actual_constraint_strings == expected_constraint_strings); + } + + #[test] + fn test_composition_2_partitions() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(3, Shape::number()); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 6); + let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); + let input2 = dag.add_dot([input1, lut3], [1, 1]); + let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); + let analyzed_dag = + super::analyze(&dag, &CONFIG, &None, LOW_PRECISION_PARTITION, true).unwrap(); + assert_eq!(analyzed_dag.nb_partitions, 2); + let actual_constraint_strings = analyzed_dag + .variance_constraints + .iter() + .map(ToString::to_string) + .collect::>(); + let expected_constraint_strings = vec![ + "1σ²Br[0] + 1σ²K[0] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", + "1σ²Br[0] + 1σ²K[0→1] + 1σ²M[1] < (2²)**-10 (6bits partition:1 count:1, dom=20)", + "1σ²Br[0] + 1σ²Br[1] + 1σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", + "1σ²Br[0] < (2²)**-7 (3bits partition:0 count:1, dom=14)", + ]; + assert_eq!(actual_constraint_strings, expected_constraint_strings); + let partitions = vec![ + LOW_PRECISION_PARTITION, + LOW_PRECISION_PARTITION, + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION, + LOW_PRECISION_PARTITION, + ]; + assert_eq!( + 
partitions, + analyzed_dag + .instrs_partition + .iter() + .map(|p| p.instruction_partition) + .collect::>() + ); + } + #[allow(clippy::needless_range_loop)] #[test] fn test_lut_sequence() { @@ -872,7 +1000,14 @@ pub mod tests { eprintln!("{}", dag.dump()); let p_cut = PrecisionCut { p_cut }; eprintln!("{p_cut}"); - let dag = super::analyze(&dag, &CONFIG, &Some(p_cut.clone()), LOW_PRECISION_PARTITION); + let dag = super::analyze( + &dag, + &CONFIG, + &Some(p_cut.clone()), + LOW_PRECISION_PARTITION, + false, + ) + .unwrap(); assert!(dag.nb_partitions == p_cut.p_cut.len() + 1); } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs similarity index 96% rename from compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs rename to compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs index 2a3997425..c41648554 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs @@ -3,6 +3,7 @@ use concrete_cpu_noise_model::gaussian_noise::noise::modulus_switching::estimate use crate::dag::unparametrized; use crate::noise_estimator::error; +use crate::optimization; use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace}; use crate::optimization::dag::multi_parameters::analyze::{analyze, AnalyzedDag}; use crate::optimization::dag::multi_parameters::fast_keyswitch; @@ -20,6 +21,7 @@ use crate::optimization::dag::multi_parameters::feasible::Feasible; use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut; use crate::optimization::dag::multi_parameters::{analyze, 
keys_spec}; +use crate::optimization::Err::{NoParametersFound, NotComposable}; use super::keys_spec::InstructionKeys; @@ -896,7 +898,7 @@ fn cross_partition(nb_partitions: usize) -> impl Iterator (0..nb_partitions).flat_map(move |a: usize| (0..nb_partitions).map(move |b: usize| (a, b))) } -#[allow(clippy::too_many_lines)] +#[allow(clippy::too_many_lines, clippy::missing_errors_doc)] pub fn optimize( dag: &unparametrized::OperationDag, config: Config, @@ -904,18 +906,18 @@ pub fn optimize( persistent_caches: &PersistDecompCaches, p_cut: &Option, default_partition: PartitionIndex, -) -> Option<(AnalyzedDag, Parameters)> { +) -> optimization::Result<(AnalyzedDag, Parameters)> { let ciphertext_modulus_log = config.ciphertext_modulus_log; let fft_precision = config.fft_precision; let security_level = config.security_level; + let composable = config.composable; let noise_config = NoiseBoundConfig { security_level, maximum_acceptable_error_probability: config.maximum_acceptable_error_probability, ciphertext_modulus_log, }; - let dag = analyze(dag, &noise_config, p_cut, default_partition); - + let dag = analyze(dag, &noise_config, p_cut, default_partition, composable)?; let kappa = error::sigma_scale_of_error_probability(config.maximum_acceptable_error_probability); @@ -971,7 +973,7 @@ pub fn optimize( params = new_params; if !params.is_feasible { if nb_partitions == 1 { - return None; + return Err(NoParametersFound); } if DEBUG { eprintln!( @@ -1019,7 +1021,7 @@ pub fn optimize( fix_point = params.clone(); } if best_params.is_none() { - return None; + return Err(NoParametersFound); } let best_params = best_params.unwrap(); sanity_check( @@ -1031,7 +1033,7 @@ pub fn optimize( &feasible, &complexity, ); - Some((dag, best_params)) + Ok((dag, best_params)) } fn used_tlu_keyswitch(dag: &AnalyzedDag) -> Vec> { @@ -1180,31 +1182,33 @@ pub fn optimize_to_circuit_solution( default_partition, ); #[allow(clippy::option_if_let_else)] - if let Some((dag, params)) = dag_and_params { - 
let ext_keys = keys_spec::ExpandedCircuitKeys::of(¶ms); - let instructions_keys = analyze::original_instrs_partition(&dag, &ext_keys); - let (ext_keys, instructions_keys) = if config.key_sharing { - let (ext_keys, key_sharing) = ext_keys.shared_keys(); - let instructions_keys = InstructionKeys::shared_keys(&instructions_keys, &key_sharing); - (ext_keys, instructions_keys) - } else { - (ext_keys, instructions_keys) - }; - let circuit_keys = ext_keys.compacted(); - keys_spec::CircuitSolution { - circuit_keys, - instructions_keys, - crt_decomposition: vec![], - complexity: params.complexity, - p_error: params.p_error, - global_p_error: params.global_p_error, - is_feasible: true, - error_msg: String::default(), + match dag_and_params { + Err(e) => keys_spec::CircuitSolution::no_solution(e.to_string()), + Ok((dag, params)) => { + let ext_keys = keys_spec::ExpandedCircuitKeys::of(¶ms); + let instructions_keys = analyze::original_instrs_partition(&dag, &ext_keys); + let (ext_keys, instructions_keys) = if config.key_sharing { + let (ext_keys, key_sharing) = ext_keys.shared_keys(); + let instructions_keys = + InstructionKeys::shared_keys(&instructions_keys, &key_sharing); + (ext_keys, instructions_keys) + } else { + (ext_keys, instructions_keys) + }; + let circuit_keys = ext_keys.compacted(); + keys_spec::CircuitSolution { + circuit_keys, + instructions_keys, + crt_decomposition: vec![], + complexity: params.complexity, + p_error: params.p_error, + global_p_error: params.global_p_error, + is_feasible: true, + error_msg: String::default(), + } } - } else { - keys_spec::CircuitSolution::no_solution("No crypto-parameters for the given constraints") } } #[cfg(test)] -include!("tests/test_optimize.rs"); +mod tests; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs new file mode 100644 index 
000000000..0c76fc09e --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -0,0 +1,799 @@ +#![allow(clippy::float_cmp)] + +use once_cell::sync::Lazy; + +use super::*; +use crate::computing_cost::cpu::CpuComplexity; +use crate::config; +use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape}; +use crate::dag::unparametrized; +use crate::optimization::dag::solo_key; +use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag}; +use crate::optimization::decomposition; + +const CIPHERTEXT_MODULUS_LOG: u32 = 64; +const FFT_PRECISION: u32 = 53; + +static SHARED_CACHES: Lazy = Lazy::new(|| { + let processing_unit = config::ProcessingUnit::Cpu; + decomposition::cache( + 128, + processing_unit, + None, + true, + CIPHERTEXT_MODULUS_LOG, + FFT_PRECISION, + ) +}); + +const _4_SIGMA: f64 = 0.000_063_342_483_999_973; + +const LOW_PARTITION: PartitionIndex = 0; + +fn optimize( + dag: &unparametrized::OperationDag, + p_cut: &Option, + default_partition: usize, +) -> Option { + let config = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + composable: false, + }; + + let search_space = SearchSpace::default_cpu(); + super::optimize( + dag, + config, + &search_space, + &SHARED_CACHES, + p_cut, + default_partition, + ) + .map_or(None, |v| Some(v.1)) +} + +fn optimize_single(dag: &unparametrized::OperationDag) -> Option { + optimize(dag, &Some(PrecisionCut { p_cut: vec![] }), LOW_PARTITION) +} + +fn equiv_single(dag: &unparametrized::OperationDag) -> Option { + let sol_mono = solo_key::optimize::tests::optimize(dag); + let sol_multi = optimize_single(dag); + if sol_mono.best_solution.is_none() != sol_multi.is_none() { + eprintln!("Not same feasibility"); + return Some(false); + }; + if sol_multi.is_none() { + return None; + } + let equiv = 
+ sol_mono.best_solution.unwrap().complexity == sol_multi.as_ref().unwrap().complexity; + if !equiv { + eprintln!("Not same complexity"); + eprintln!("Single: {:?}", sol_mono.best_solution.unwrap()); + eprintln!("Multi: {:?}", sol_multi.clone().unwrap().complexity); + eprintln!("Multi: {:?}", sol_multi.unwrap()); + } + Some(equiv) +} + +#[test] +fn optimize_simple_parameter_v0_dag() { + for precision in 1..11 { + for manp in 1..25 { + eprintln!("P M {precision} {manp}"); + let dag = v0_dag(0, precision, manp as f64); + if let Some(equiv) = equiv_single(&dag) { + assert!(equiv); + } else { + break; + } + } + } +} + +#[test] +fn optimize_simple_parameter_rounded_lut_2_layers() { + for accumulator_precision in 1..11 { + for precision in 1..accumulator_precision { + for manp in [1, 8, 16] { + eprintln!("CASE {accumulator_precision} {precision} {manp}"); + let dag = v0_dag(0, precision, manp as f64); + if let Some(equiv) = equiv_single(&dag) { + assert!(equiv); + } else { + break; + } + } + } + } +} + +fn equiv_2_single( + dag_multi: &unparametrized::OperationDag, + dag_1: &unparametrized::OperationDag, + dag_2: &unparametrized::OperationDag, +) -> Option { + let precision_max = dag_multi.out_precisions.iter().copied().max().unwrap(); + let p_cut = Some(PrecisionCut { + p_cut: vec![precision_max - 1], + }); + eprintln!("{dag_multi}"); + let sol_single_1 = solo_key::optimize::tests::optimize(dag_1); + let sol_single_2 = solo_key::optimize::tests::optimize(dag_2); + let sol_multi = optimize(dag_multi, &p_cut, LOW_PARTITION); + let sol_multi_1 = optimize(dag_1, &p_cut, LOW_PARTITION); + let sol_multi_2 = optimize(dag_2, &p_cut, LOW_PARTITION); + let feasible_1 = sol_single_1.best_solution.is_some(); + let feasible_2 = sol_single_2.best_solution.is_some(); + let feasible_multi = sol_multi.is_some(); + if (feasible_1 && feasible_2) != feasible_multi { + eprintln!("Not same feasibility {feasible_1} {feasible_2} {feasible_multi}"); + return Some(false); + } + if 
sol_multi.is_none() { + return None; + } + let sol_multi = sol_multi.unwrap(); + let sol_multi_1 = sol_multi_1.unwrap(); + let sol_multi_2 = sol_multi_2.unwrap(); + let cost_1 = sol_single_1.best_solution.unwrap().complexity; + let cost_2 = sol_single_2.best_solution.unwrap().complexity; + let cost_multi = sol_multi.complexity; + let equiv = cost_1 + cost_2 == cost_multi + && cost_1 == sol_multi_1.complexity + && cost_2 == sol_multi_2.complexity + && sol_multi.micro_params.ks[0][0].unwrap().decomp + == sol_multi_1.micro_params.ks[0][0].unwrap().decomp + && sol_multi.micro_params.ks[1][1].unwrap().decomp + == sol_multi_2.micro_params.ks[0][0].unwrap().decomp; + if !equiv { + eprintln!("Not same complexity"); + eprintln!("Multi: {cost_multi:?}"); + eprintln!("Added Single: {:?}", cost_1 + cost_2); + eprintln!("Single1: {:?}", sol_single_1.best_solution.unwrap()); + eprintln!("Single2: {:?}", sol_single_2.best_solution.unwrap()); + eprintln!("Multi: {sol_multi:?}"); + eprintln!("Multi1: {sol_multi_1:?}"); + eprintln!("Multi2: {sol_multi_2:?}"); + } + Some(equiv) +} + +#[test] +fn optimize_multi_independant_2_precisions() { + let sum_size = 0; + for precision1 in 1..11 { + for precision2 in (precision1 + 1)..11 { + for manp in [1, 8, 16] { + eprintln!("CASE {precision1} {precision2} {manp}"); + let noise_factor = manp as f64; + let mut dag_multi = v0_dag(sum_size, precision1, noise_factor); + add_v0_dag(&mut dag_multi, sum_size, precision2, noise_factor); + let dag_1 = v0_dag(sum_size, precision1, noise_factor); + let dag_2 = v0_dag(sum_size, precision2, noise_factor); + if let Some(equiv) = equiv_2_single(&dag_multi, &dag_1, &dag_2) { + assert!(equiv, "FAILED ON {precision1} {precision2} {manp}"); + } else { + break; + } + } + } + } +} + +fn dag_lut_sum_of_2_partitions_2_layer( + precision1: u8, + precision2: u8, + final_lut: bool, +) -> unparametrized::OperationDag { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(precision1, 
Shape::number()); + let input2 = dag.add_input(precision2, Shape::number()); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision1); + let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, precision2); + let lut1 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision2); + let lut2 = dag.add_lut(lut2, FunctionTable::UNKWOWN, precision2); + let dot = dag.add_dot([lut1, lut2], [1, 1]); + if final_lut { + _ = dag.add_lut(dot, FunctionTable::UNKWOWN, precision1); + } + dag +} + +#[test] +fn optimize_multi_independant_2_partitions_finally_added() { + let default_partition = 0; + let single_precision_sol: Vec<_> = (0..11) + .map(|precision| { + let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false); + optimize_single(&dag) + }) + .collect(); + + for precision1 in 1..11 { + for precision2 in (precision1 + 1)..11 { + let p_cut = Some(PrecisionCut { + p_cut: vec![precision1], + }); + let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, false); + let sol_1 = single_precision_sol[precision1 as usize].clone(); + let sol_2 = single_precision_sol[precision2 as usize].clone(); + let sol_multi = optimize(&dag_multi, &p_cut, LOW_PARTITION); + let feasible_multi = sol_multi.is_some(); + let feasible_2 = sol_2.is_some(); + assert!(feasible_multi); + assert!(feasible_2); + let sol_multi = sol_multi.unwrap(); + let sol_1 = sol_1.unwrap(); + let sol_2 = sol_2.unwrap(); + assert!(sol_1.complexity < sol_multi.complexity); + if REAL_FAST_KS { + assert!(sol_multi.complexity < sol_2.complexity); + } + eprintln!("{:?}", sol_multi.micro_params.fks); + let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2] + [default_partition] + .unwrap() + .complexity; + let sol_multi_without_fks = sol_multi.complexity - fks_complexity; + let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0; + if REAL_FAST_KS { + assert!(sol_multi.macro_params[1] == sol_2.macro_params[0]); + } + // The smallest the precision the more 
fks noise break partition independence + #[allow(clippy::collapsible_else_if)] + let maximal_relative_degratdation = if REAL_FAST_KS { + if precision1 < 4 { + 1.1 + } else if precision1 <= 7 { + 1.03 + } else { + 1.001 + } + } else { + if precision1 < 4 { + 1.5 + } else if precision1 <= 7 { + 1.8 + } else { + 1.6 + } + }; + assert!( + sol_multi_without_fks / perfect_complexity < maximal_relative_degratdation, + "{precision1} {precision2} {} < {maximal_relative_degratdation}", + sol_multi_without_fks / perfect_complexity + ); + } + } +} + +#[test] +fn optimize_multi_independant_2_partitions_finally_added_and_luted() { + let default_partition = 0; + let single_precision_sol: Vec<_> = (0..11) + .map(|precision| { + let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true); + optimize_single(&dag) + }) + .collect(); + for precision1 in 1..11 { + for precision2 in (precision1 + 1)..11 { + let p_cut = Some(PrecisionCut { + p_cut: vec![precision1], + }); + let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, true); + let sol_1 = single_precision_sol[precision1 as usize].clone(); + let sol_2 = single_precision_sol[precision2 as usize].clone(); + let sol_multi = optimize(&dag_multi, &p_cut, 0); + let feasible_multi = sol_multi.is_some(); + let feasible_2 = sol_2.is_some(); + assert!(feasible_multi); + assert!(feasible_2); + let sol_multi = sol_multi.unwrap(); + let sol_1 = sol_1.unwrap(); + let sol_2 = sol_2.unwrap(); + // The smallest the precision the more fks noise dominate + assert!(sol_1.complexity < sol_multi.complexity); + assert!(sol_multi.complexity < sol_2.complexity); + let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2] + [default_partition] + .unwrap() + .complexity; + let sol_multi_without_fks = sol_multi.complexity - fks_complexity; + let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0; + let relative_degradation = sol_multi_without_fks / perfect_complexity; + 
#[allow(clippy::collapsible_else_if)] + let maxim_relative_degradation = if REAL_FAST_KS { + if precision1 < 4 { + 1.2 + } else if precision1 <= 7 { + 1.19 + } else { + 1.15 + } + } else { + if precision1 < 4 { + 1.45 + } else if precision1 <= 7 { + 1.8 + } else { + 1.6 + } + }; + assert!( + relative_degradation < maxim_relative_degradation, + "{precision1} {precision2} {}", + sol_multi_without_fks / perfect_complexity + ); + } + } +} + +fn optimize_rounded(dag: &unparametrized::OperationDag) -> Option { + let p_cut = Some(PrecisionCut { p_cut: vec![1] }); + let default_partition = 0; + optimize(dag, &p_cut, default_partition) +} + +fn dag_rounded_lut_2_layers( + accumulator_precision: usize, + precision: usize, +) -> unparametrized::OperationDag { + let out_precision = accumulator_precision as u8; + let rounded_precision = precision as u8; + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(precision as u8, Shape::number()); + let rounded1 = dag.add_expanded_rounded_lut( + input1, + FunctionTable::UNKWOWN, + rounded_precision, + out_precision, + ); + let rounded2 = dag.add_expanded_rounded_lut( + rounded1, + FunctionTable::UNKWOWN, + rounded_precision, + out_precision, + ); + let _rounded3 = + dag.add_expanded_rounded_lut(rounded2, FunctionTable::UNKWOWN, rounded_precision, 1); + dag +} + +fn test_optimize_v3_expanded_round( + precision_acc: usize, + precision_tlu: usize, + minimal_speedup: f64, +) { + let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu); + let sol_mono = solo_key::optimize::tests::optimize(&dag) + .best_solution + .unwrap(); + let sol = optimize_rounded(&dag).unwrap(); + let speedup = sol_mono.complexity / sol.complexity; + assert!( + speedup >= minimal_speedup, + "Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}" + ); + let expected_ks = [ + [true, true], // KS[0], KS[0->1] + [false, true], // KS[1] + ]; + let expected_fks = [ + [false, false], + [true, false], // 
FKS[1->0] + ]; + for (src, dst) in cross_partition(2) { + assert!(sol.micro_params.ks[src][dst].is_some() == expected_ks[src][dst]); + assert!(sol.micro_params.fks[src][dst].is_some() == expected_fks[src][dst]); + } +} + +#[test] +fn test_optimize_v3_expanded_round_16_8() { + if REAL_FAST_KS { + test_optimize_v3_expanded_round(16, 8, 5.5); + } else { + test_optimize_v3_expanded_round(16, 8, 3.9); + } +} + +#[test] +fn test_optimize_v3_expanded_round_16_6() { + if REAL_FAST_KS { + test_optimize_v3_expanded_round(16, 6, 3.3); + } else { + test_optimize_v3_expanded_round(16, 6, 2.6); + } +} + +#[test] +fn optimize_v3_direct_round() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(16, Shape::number()); + _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16); + let sol = optimize_rounded(&dag).unwrap(); + let sol_mono = solo_key::optimize::tests::optimize(&dag) + .best_solution + .unwrap(); + let minimal_speedup = 8.6; + let speedup = sol_mono.complexity / sol.complexity; + assert!( + speedup >= minimal_speedup, + "Speedup {speedup} smaller than {minimal_speedup}" + ); +} + +#[test] +fn optimize_sign_extract() { + let precision = 8; + let high_precision = 16; + let mut dag = unparametrized::OperationDag::new(); + let complexity = LevelledComplexity::ZERO; + let free_small_input1 = dag.add_input(precision, Shape::number()); + let small_input1 = dag.add_lut(free_small_input1, FunctionTable::UNKWOWN, precision); + let small_input1 = dag.add_lut(small_input1, FunctionTable::UNKWOWN, high_precision); + let input1 = dag.add_levelled_op( + [small_input1], + complexity, + 1.0, + Shape::vector(1_000_000), + "comment", + ); + let rounded1 = dag.add_expanded_round(input1, 1); + let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1); + let sol = optimize_rounded(&dag).unwrap(); + let sol_mono = solo_key::optimize::tests::optimize(&dag) + .best_solution + .unwrap(); + let speedup = sol_mono.complexity / sol.complexity; + 
let minimal_speedup = if REAL_FAST_KS { 80.0 } else { 30.0 }; + assert!( + speedup >= minimal_speedup, + "Speedup {speedup} smaller than {minimal_speedup}" + ); +} + +fn test_partition_chain(decreasing: bool) { + // tlu chain with decreasing precision (decreasing partition index) + // check that increasing partitionning gaves faster solutions + // check solution has the right structure + let mut dag = unparametrized::OperationDag::new(); + let min_precision = 6; + let max_precision = 8; + let mut input_precisions: Vec<_> = (min_precision..=max_precision).collect(); + if decreasing { + input_precisions.reverse(); + } + let mut lut_input = dag.add_input(input_precisions[0], Shape::number()); + for &out_precision in &input_precisions { + lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision); + } + lut_input = dag.add_lut( + lut_input, + FunctionTable::UNKWOWN, + *input_precisions.last().unwrap(), + ); + _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision); + let mut p_cut = PrecisionCut { p_cut: vec![] }; + let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); + assert!(sol.macro_params.len() == 1); + let mut complexity = sol.complexity; + for &out_precision in &input_precisions { + if out_precision == max_precision { + // There is nothing to cut above max_precision + continue; + } + p_cut.p_cut.push(out_precision); + p_cut.p_cut.sort_unstable(); + eprintln!("PCUT {p_cut}"); + let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); + let nb_partitions = sol.macro_params.len(); + assert!( + nb_partitions == (p_cut.p_cut.len() + 1), + "bad nb partitions {} {p_cut}", + sol.macro_params.len() + ); + assert!( + sol.complexity < complexity, + "{} < {complexity} {out_precision} / {max_precision}", + sol.complexity + ); + for (src, dst) in cross_partition(nb_partitions) { + let ks = sol.micro_params.ks[src][dst]; + eprintln!("{} {src} {dst}", ks.is_some()); + let expected_ks = (!decreasing || src == dst + 1) && (decreasing || src + 1 
== dst) + || (src == dst && (src == 0 || src == nb_partitions - 1)); + assert!( + ks.is_some() == expected_ks, + "{:?} {:?}", + ks.is_some(), + expected_ks + ); + let fks = sol.micro_params.fks[src][dst]; + assert!(fks.is_none()); + } + complexity = sol.complexity; + } + let sol = optimize(&dag, &None, 0); + assert!(sol.unwrap().complexity == complexity); +} + +#[test] +fn test_partition_decreasing_chain() { + test_partition_chain(true); +} + +#[test] +fn test_partition_increasing_chain() { + test_partition_chain(true); +} + +const MAX_WEIGHT: &[u64] = &[ + // max v0 weight for each precision + 1_073_741_824, + 1_073_741_824, // 2**30, 1b + 536_870_912, // 2**29, 2b + 268_435_456, // 2**28, 3b + 67_108_864, // 2**26, 4b + 16_777_216, // 2**24, 5b + 4_194_304, // 2**22, 6b + 1_048_576, // 2**20, 7b + 262_144, // 2**18, 8b + 65_536, // 2**16, 9b + 16384, // 2**14, 10b + 2048, // 2**11, 11b +]; + +#[test] +fn test_independant_partitions_non_feasible_single_params() { + // generate hard circuit, non feasible with single parameters + // composed of independant partitions so we know the optimal result + let precisions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; + let sum_size = 0; + let noise_factor = MAX_WEIGHT[precisions[0]] as f64; + let mut dag = v0_dag(sum_size, precisions[0] as u64, noise_factor); + let sol_single = optimize_single(&dag); + let mut optimal_complexity = sol_single.as_ref().unwrap().complexity; + let mut optimal_p_error = sol_single.unwrap().p_error; + for &out_precision in &precisions[1..] 
{ + let noise_factor = MAX_WEIGHT[out_precision] as f64; + add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor); + let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor)); + optimal_complexity += sol_single.as_ref().unwrap().complexity; + optimal_p_error += sol_single.as_ref().unwrap().p_error; + } + // check non feasible in single + let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution; + assert!(sol_single.is_none()); + // solves in multi + let sol = optimize(&dag, &None, 0); + assert!(sol.is_some()); + let sol = sol.unwrap(); + // check optimality + assert!(sol.complexity / optimal_complexity < 1.0 + f64::EPSILON); + assert!(sol.p_error / optimal_p_error < 1.0 + f64::EPSILON); +} + +#[test] +fn test_chained_partitions_non_feasible_single_params() { + // generate hard circuit, non feasible with single parameters + let precisions = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; + // Note: reversing chain have issues for connecting lower bits to 7 bits, there may be no feasible solution + let mut dag = unparametrized::OperationDag::new(); + let mut lut_input = dag.add_input(precisions[0], Shape::number()); + for out_precision in precisions { + let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64; + lut_input = dag.add_levelled_op( + [lut_input], + LevelledComplexity::ZERO, + noise_factor, + Shape::number(), + "", + ); + lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision); + } + _ = dag.add_lut( + lut_input, + FunctionTable::UNKWOWN, + *precisions.last().unwrap(), + ); + let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution; + assert!(sol_single.is_none()); + let sol = optimize(&dag, &None, 0); + assert!(sol.is_some()); +} + +#[test] +fn test_multi_rounded_fks_coherency() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(16, Shape::number()); + let reduced_8 = dag.add_expanded_rounded_lut(input1, 
FunctionTable::UNKWOWN, 8, 8); + let reduced_4 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); + _ = dag.add_dot([reduced_8, reduced_4], [1, 1]); + let sol = optimize(&dag, &None, 0); + assert!(sol.is_some()); + let sol = sol.unwrap(); + for (src, dst) in cross_partition(sol.macro_params.len()) { + if let Some(fks) = sol.micro_params.fks[src][dst] { + assert!(fks.src_glwe_param == sol.macro_params[src].unwrap().glwe_params); + assert!(fks.dst_glwe_param == sol.macro_params[dst].unwrap().glwe_params); + } + } +} + +#[test] +fn test_levelled_only() { + let mut dag = unparametrized::OperationDag::new(); + let _ = dag.add_input(22, Shape::number()); + let config = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + composable: false, + }; + + let search_space = SearchSpace::default_cpu(); + let sol = + super::optimize_to_circuit_solution(&dag, config, &search_space, &SHARED_CACHES, &None); + let sol_mono = solo_key::optimize::tests::optimize(&dag) + .best_solution + .unwrap(); + assert!(sol.circuit_keys.secret_keys.len() == 1); + assert!(sol.circuit_keys.secret_keys[0].polynomial_size == sol_mono.glwe_polynomial_size); +} + +#[test] +fn test_big_secret_key_sharing() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(4, Shape::number()); + let input2 = dag.add_input(5, Shape::number()); + let input2 = dag.add_dot([input2], [128]); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); + let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); + let _ = dag.add_dot([lut1, lut2], [16, 1]); + let config_sharing = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + composable: false, + }; + let config_no_sharing = 
Config { + key_sharing: false, + ..config_sharing + }; + let mut search_space = SearchSpace::default_cpu(); + // eprintln!("{:?}", search_space); + search_space.glwe_dimensions = vec![1]; // forcing big key sharing + let sol_sharing = super::optimize_to_circuit_solution( + &dag, + config_sharing, + &search_space, + &SHARED_CACHES, + &None, + ); + eprintln!("NO SHARING"); + let sol_no_sharing = super::optimize_to_circuit_solution( + &dag, + config_no_sharing, + &search_space, + &SHARED_CACHES, + &None, + ); + let keys_sharing = sol_sharing.circuit_keys; + let keys_no_sharing = sol_no_sharing.circuit_keys; + assert!(keys_sharing.secret_keys.len() == 3); + assert!(keys_no_sharing.secret_keys.len() == 4); + assert!(keys_sharing.conversion_keyswitch_keys.is_empty()); + assert!(keys_no_sharing.conversion_keyswitch_keys.len() == 1); + assert!(keys_sharing.bootstrap_keys.len() == keys_no_sharing.bootstrap_keys.len()); + assert!(keys_sharing.keyswitch_keys.len() == keys_no_sharing.keyswitch_keys.len()); +} + +#[test] +fn test_big_and_small_secret_key() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(4, Shape::number()); + let input2 = dag.add_input(5, Shape::number()); + let input2 = dag.add_dot([input2], [128]); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); + let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); + let _ = dag.add_dot([lut1, lut2], [16, 1]); + let config_sharing = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + composable: false, + }; + let config_no_sharing = Config { + key_sharing: false, + ..config_sharing + }; + let mut search_space = SearchSpace::default_cpu(); + search_space.glwe_dimensions = vec![1]; // forcing big key sharing + search_space.internal_lwe_dimensions = vec![768]; // forcing small key sharing + let sol_sharing = 
super::optimize_to_circuit_solution( + &dag, + config_sharing, + &search_space, + &SHARED_CACHES, + &None, + ); + let sol_no_sharing = super::optimize_to_circuit_solution( + &dag, + config_no_sharing, + &search_space, + &SHARED_CACHES, + &None, + ); + let keys_sharing = sol_sharing.circuit_keys; + let keys_no_sharing = sol_no_sharing.circuit_keys; + assert!(keys_sharing.secret_keys.len() == 2); + assert!(keys_no_sharing.secret_keys.len() == 4); + assert!(keys_sharing.conversion_keyswitch_keys.is_empty()); + assert!(keys_no_sharing.conversion_keyswitch_keys.len() == 1); + // boostrap are merged due to same (level, base) + assert!(keys_sharing.bootstrap_keys.len() + 1 == keys_no_sharing.bootstrap_keys.len()); + // keyswitch are still different due to another (level, base) + assert!(keys_sharing.keyswitch_keys.len() == keys_no_sharing.keyswitch_keys.len()); +} + +#[test] +fn test_composition_2_partitions() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(3, Shape::number()); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 6); + let lut3 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 3); + let input2 = dag.add_dot([input1, lut3], [1, 1]); + let _ = dag.add_lut(input2, FunctionTable::UNKWOWN, 3); + let normal_config = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + composable: false, + }; + let composed_config = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: false, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + composable: true, + }; + let search_space = SearchSpace::default_cpu(); + let normal_sol = super::optimize(&dag, normal_config, &search_space, &SHARED_CACHES, &None, 1) + .unwrap() + .1; + let composed_sol = super::optimize( + &dag, + composed_config, + &search_space, + 
&SHARED_CACHES, + &None, + 1, + ) + .unwrap() + .1; + assert!(composed_sol.is_feasible); + assert!(composed_sol.complexity > normal_sol.complexity); +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index 1fd8eeed0..1d28bf48b 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -46,7 +46,7 @@ impl Blocks { // Extract block of instructions connected by levelled ops. // This facilitates reasonning about conflicts on levelled ops. #[allow(clippy::match_same_arms)] -fn extract_levelled_block(dag: &unparametrized::OperationDag) -> Blocks { +fn extract_levelled_block(dag: &unparametrized::OperationDag, composable: bool) -> Blocks { let mut uf = UnionFind::new(dag.operators.len()); for (op_i, op) in dag.operators.iter().enumerate() { match op { @@ -64,6 +64,16 @@ fn extract_levelled_block(dag: &unparametrized::OperationDag) -> Blocks { Op::Round { .. } => unreachable!("Round should have been expanded"), }; } + if composable { + // Without knowledge of how outputs are forwarded to inputs, we can't do better than putting + // all inputs and outputs in the same partition. 
+ let mut input_iter = dag.get_input_index_iter(); + let first_inp = input_iter.next().unwrap(); + dag.get_output_index() + .into_iter() + .chain(input_iter) + .for_each(|ind| uf.union(first_inp, ind)); + } Blocks::from(uf) } @@ -130,8 +140,9 @@ fn resolve_by_levelled_block( dag: &unparametrized::OperationDag, p_cut: &PrecisionCut, default_partition: PartitionIndex, + composable: bool, ) -> Partitions { - let blocks = extract_levelled_block(dag); + let blocks = extract_levelled_block(dag, composable); let constraints_by_blocks = levelled_blocks_constraints(dag, &blocks, p_cut); let present_partitions: HashSet = constraints_by_blocks .iter() @@ -225,11 +236,12 @@ pub fn partitionning_with_preferred( dag: &unparametrized::OperationDag, p_cut: &PrecisionCut, default_partition: PartitionIndex, + composable: bool, ) -> Partitions { if p_cut.p_cut.is_empty() { only_1_partition(dag) } else { - resolve_by_levelled_block(dag, p_cut, default_partition) + resolve_by_levelled_block(dag, p_cut, default_partition, composable) } } @@ -248,21 +260,22 @@ pub mod tests { PrecisionCut { p_cut: vec![2] } } - fn partitionning_no_p_cut(dag: &unparametrized::OperationDag) -> Partitions { + fn partitionning_no_p_cut(dag: &unparametrized::OperationDag, composable: bool) -> Partitions { let p_cut = PrecisionCut { p_cut: vec![] }; - partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION) + partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION, composable) } - fn partitionning(dag: &unparametrized::OperationDag) -> Partitions { - partitionning_with_preferred(dag, &default_p_cut(), LOW_PRECISION_PARTITION) + fn partitionning(dag: &unparametrized::OperationDag, composable: bool) -> Partitions { + partitionning_with_preferred(dag, &default_p_cut(), LOW_PRECISION_PARTITION, composable) } fn partitionning_with_preferred( dag: &unparametrized::OperationDag, p_cut: &PrecisionCut, default_partition: usize, + composable: bool, ) -> Partitions { - 
super::partitionning_with_preferred(dag, p_cut, default_partition) + super::partitionning_with_preferred(dag, p_cut, default_partition, composable) } pub fn show_partitionning( @@ -304,7 +317,7 @@ pub mod tests { let mut dag = unparametrized::OperationDag::new(); let input1 = dag.add_input(16, Shape::number()); _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); - let instrs_partition = partitionning_no_p_cut(&dag).instrs_partition; + let instrs_partition = partitionning_no_p_cut(&dag, false).instrs_partition; for instr_partition in instrs_partition { assert!(instr_partition.instruction_partition == LOW_PRECISION_PARTITION); assert!(instr_partition.no_transition()); @@ -315,13 +328,31 @@ pub mod tests { fn test_1_input_2_partitions() { let mut dag = unparametrized::OperationDag::new(); _ = dag.add_input(1, Shape::number()); - let partitions = partitionning(&dag); + let partitions = partitionning(&dag, false); assert!(partitions.nb_partitions == 1); let instrs_partition = partitions.instrs_partition; assert!(instrs_partition[0].instruction_partition == LOW_PRECISION_PARTITION); assert!(partitions.nb_partitions == 1); } + #[test] + fn test_2_partitions_with_without_compo() { + let mut dag = unparametrized::OperationDag::new(); + let input = dag.add_input(10, Shape::number()); + let lut1 = dag.add_lut(input, FunctionTable::UNKWOWN, 2); + let output = dag.add_lut(lut1, FunctionTable::UNKWOWN, 10); + let partitions = partitionning(&dag, false); + assert!( + partitions.instrs_partition[input.i].instruction_partition + != partitions.instrs_partition[output.i].instruction_partition + ); + let partitions = partitionning(&dag, true); + assert!( + partitions.instrs_partition[input.i].instruction_partition + == partitions.instrs_partition[output.i].instruction_partition + ); + } + #[test] fn test_2_lut_sequence() { let mut dag = unparametrized::OperationDag::new(); @@ -338,7 +369,7 @@ pub mod tests { expected_partitions.push(LOW_PRECISION_PARTITION); let 
lut5 = dag.add_lut(lut4, FunctionTable::UNKWOWN, 8); expected_partitions.push(HIGH_PRECISION_PARTITION); - let partitions = partitionning(&dag); + let partitions = partitionning(&dag, false); assert!(partitions.nb_partitions == 2); let instrs_partition = partitions.instrs_partition; let consider = |op_i: OperatorIndex| &instrs_partition[op_i.i]; @@ -359,7 +390,7 @@ pub mod tests { let input2 = dag.add_input(1, Shape::number()); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8); let _dot = dag.add_dot([input1, lut2], Weights::from([1, 1])); - let partitions = partitionning(&dag); + let partitions = partitionning(&dag, false); assert!(partitions.nb_partitions == 1); } @@ -370,7 +401,7 @@ pub mod tests { let input2 = dag.add_input(1, Shape::number()); let lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, 1); let _dot = dag.add_dot([input2, lut2], Weights::from([1, 1])); - let partitions = partitionning(&dag); + let partitions = partitionning(&dag, false); assert!(partitions.nb_partitions == 1); } @@ -382,7 +413,7 @@ pub mod tests { let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8); let dot = dag.add_dot([lut1, lut2], Weights::from([1, 1])); - let partitions = partitionning(&dag); + let partitions = partitionning(&dag, false); let consider = |op_i: OperatorIndex| &partitions.instrs_partition[op_i.i]; // input1 let p = consider(input1); @@ -438,7 +469,7 @@ pub mod tests { let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let rounded2 = dag.add_expanded_round(lut1, precision); let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision); - let partitions = partitionning(&dag); + let partitions = partitionning(&dag, false); let consider = |op_i| &partitions.instrs_partition[op_i]; // First layer is fully LOW_PRECISION_PARTITION for op_i in input1.i..lut1.i { @@ -488,7 +519,7 @@ pub mod tests { let rounded1 = dag.add_expanded_round(input1, precision); let 
rounded_layer: Vec<_> = ((input1.i + 1)..rounded1.i).collect(); let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); - let partitions = partitionning(&dag); + let partitions = partitionning(&dag, false); let consider = |op_i: usize| &partitions.instrs_partition[op_i]; // First layer is fully HIGH_PRECISION_PARTITION @@ -549,7 +580,7 @@ pub mod tests { let rounded_layer = (input1.i + 1)..rounded1.i; let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); let partitions = - partitionning_with_preferred(&dag, &default_p_cut(), HIGH_PRECISION_PARTITION); + partitionning_with_preferred(&dag, &default_p_cut(), HIGH_PRECISION_PARTITION, false); show_partitionning(&dag, &partitions.instrs_partition); let consider = |op_i: usize| &partitions.instrs_partition[op_i]; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs deleted file mode 100644 index 018af83d7..000000000 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs +++ /dev/null @@ -1,758 +0,0 @@ -// part of optimize.rs -#[cfg(test)] -mod tests { - #![allow(clippy::float_cmp)] - - use once_cell::sync::Lazy; - - use super::*; - use crate::computing_cost::cpu::CpuComplexity; - use crate::config; - use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape}; - use crate::dag::unparametrized; - use crate::optimization::dag::solo_key; - use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag}; - use crate::optimization::decomposition; - - const CIPHERTEXT_MODULUS_LOG: u32 = 64; - const FFT_PRECISION: u32 = 53; - - static SHARED_CACHES: Lazy = Lazy::new(|| { - let processing_unit = config::ProcessingUnit::Cpu; - decomposition::cache( - 128, - processing_unit, - None, - true, - CIPHERTEXT_MODULUS_LOG, - FFT_PRECISION, - ) - }); 
- - const _4_SIGMA: f64 = 0.000_063_342_483_999_973; - - const LOW_PARTITION: PartitionIndex = 0; - - fn optimize( - dag: &unparametrized::OperationDag, - p_cut: &Option, - default_partition: usize, - ) -> Option { - let config = Config { - security_level: 128, - maximum_acceptable_error_probability: _4_SIGMA, - key_sharing: true, - ciphertext_modulus_log: 64, - fft_precision: 53, - complexity_model: &CpuComplexity::default(), - }; - - let search_space = SearchSpace::default_cpu(); - super::optimize( - dag, - config, - &search_space, - &SHARED_CACHES, - p_cut, - default_partition, - ) - .map(|v| v.1) - } - - fn optimize_single(dag: &unparametrized::OperationDag) -> Option { - optimize(dag, &Some(PrecisionCut { p_cut: vec![] }), LOW_PARTITION) - } - - fn equiv_single(dag: &unparametrized::OperationDag) -> Option { - let sol_mono = solo_key::optimize::tests::optimize(dag); - let sol_multi = optimize_single(dag); - if sol_mono.best_solution.is_none() != sol_multi.is_none() { - eprintln!("Not same feasibility"); - return Some(false); - }; - if sol_multi.is_none() { - return None; - } - let equiv = - sol_mono.best_solution.unwrap().complexity == sol_multi.as_ref().unwrap().complexity; - if !equiv { - eprintln!("Not same complexity"); - eprintln!("Single: {:?}", sol_mono.best_solution.unwrap()); - eprintln!("Multi: {:?}", sol_multi.clone().unwrap().complexity); - eprintln!("Multi: {:?}", sol_multi.unwrap()); - } - Some(equiv) - } - - #[test] - fn optimize_simple_parameter_v0_dag() { - for precision in 1..11 { - for manp in 1..25 { - eprintln!("P M {precision} {manp}"); - let dag = v0_dag(0, precision, manp as f64); - if let Some(equiv) = equiv_single(&dag) { - assert!(equiv); - } else { - break; - } - } - } - } - - #[test] - fn optimize_simple_parameter_rounded_lut_2_layers() { - for accumulator_precision in 1..11 { - for precision in 1..accumulator_precision { - for manp in [1, 8, 16] { - eprintln!("CASE {accumulator_precision} {precision} {manp}"); - let dag = 
v0_dag(0, precision, manp as f64); - if let Some(equiv) = equiv_single(&dag) { - assert!(equiv); - } else { - break; - } - } - } - } - } - - fn equiv_2_single( - dag_multi: &unparametrized::OperationDag, - dag_1: &unparametrized::OperationDag, - dag_2: &unparametrized::OperationDag, - ) -> Option { - let precision_max = dag_multi.out_precisions.iter().copied().max().unwrap(); - let p_cut = Some(PrecisionCut { - p_cut: vec![precision_max - 1], - }); - eprintln!("{dag_multi}"); - let sol_single_1 = solo_key::optimize::tests::optimize(dag_1); - let sol_single_2 = solo_key::optimize::tests::optimize(dag_2); - let sol_multi = optimize(dag_multi, &p_cut, LOW_PARTITION); - let sol_multi_1 = optimize(dag_1, &p_cut, LOW_PARTITION); - let sol_multi_2 = optimize(dag_2, &p_cut, LOW_PARTITION); - let feasible_1 = sol_single_1.best_solution.is_some(); - let feasible_2 = sol_single_2.best_solution.is_some(); - let feasible_multi = sol_multi.is_some(); - if (feasible_1 && feasible_2) != feasible_multi { - eprintln!("Not same feasibility {feasible_1} {feasible_2} {feasible_multi}"); - return Some(false); - } - if sol_multi.is_none() { - return None; - } - let sol_multi = sol_multi.unwrap(); - let sol_multi_1 = sol_multi_1.unwrap(); - let sol_multi_2 = sol_multi_2.unwrap(); - let cost_1 = sol_single_1.best_solution.unwrap().complexity; - let cost_2 = sol_single_2.best_solution.unwrap().complexity; - let cost_multi = sol_multi.complexity; - let equiv = cost_1 + cost_2 == cost_multi - && cost_1 == sol_multi_1.complexity - && cost_2 == sol_multi_2.complexity - && sol_multi.micro_params.ks[0][0].unwrap().decomp - == sol_multi_1.micro_params.ks[0][0].unwrap().decomp - && sol_multi.micro_params.ks[1][1].unwrap().decomp - == sol_multi_2.micro_params.ks[0][0].unwrap().decomp; - if !equiv { - eprintln!("Not same complexity"); - eprintln!("Multi: {cost_multi:?}"); - eprintln!("Added Single: {:?}", cost_1 + cost_2); - eprintln!("Single1: {:?}", sol_single_1.best_solution.unwrap()); - 
eprintln!("Single2: {:?}", sol_single_2.best_solution.unwrap()); - eprintln!("Multi: {sol_multi:?}"); - eprintln!("Multi1: {sol_multi_1:?}"); - eprintln!("Multi2: {sol_multi_2:?}"); - } - Some(equiv) - } - - #[test] - fn optimize_multi_independant_2_precisions() { - let sum_size = 0; - for precision1 in 1..11 { - for precision2 in (precision1 + 1)..11 { - for manp in [1, 8, 16] { - eprintln!("CASE {precision1} {precision2} {manp}"); - let noise_factor = manp as f64; - let mut dag_multi = v0_dag(sum_size, precision1, noise_factor); - add_v0_dag(&mut dag_multi, sum_size, precision2, noise_factor); - let dag_1 = v0_dag(sum_size, precision1, noise_factor); - let dag_2 = v0_dag(sum_size, precision2, noise_factor); - if let Some(equiv) = equiv_2_single(&dag_multi, &dag_1, &dag_2) { - assert!(equiv, "FAILED ON {precision1} {precision2} {manp}"); - } else { - break; - } - } - } - } - } - - fn dag_lut_sum_of_2_partitions_2_layer( - precision1: u8, - precision2: u8, - final_lut: bool, - ) -> unparametrized::OperationDag { - let mut dag = unparametrized::OperationDag::new(); - let input1 = dag.add_input(precision1, Shape::number()); - let input2 = dag.add_input(precision2, Shape::number()); - let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision1); - let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, precision2); - let lut1 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision2); - let lut2 = dag.add_lut(lut2, FunctionTable::UNKWOWN, precision2); - let dot = dag.add_dot([lut1, lut2], [1, 1]); - if final_lut { - _ = dag.add_lut(dot, FunctionTable::UNKWOWN, precision1); - } - dag - } - - #[test] - fn optimize_multi_independant_2_partitions_finally_added() { - let default_partition = 0; - let single_precision_sol: Vec<_> = (0..11) - .map(|precision| { - let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false); - optimize_single(&dag) - }) - .collect(); - - for precision1 in 1..11 { - for precision2 in (precision1 + 1)..11 { - let p_cut = 
Some(PrecisionCut { - p_cut: vec![precision1], - }); - let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, false); - let sol_1 = single_precision_sol[precision1 as usize].clone(); - let sol_2 = single_precision_sol[precision2 as usize].clone(); - let sol_multi = optimize(&dag_multi, &p_cut, LOW_PARTITION); - let feasible_multi = sol_multi.is_some(); - let feasible_2 = sol_2.is_some(); - assert!(feasible_multi); - assert!(feasible_2); - let sol_multi = sol_multi.unwrap(); - let sol_1 = sol_1.unwrap(); - let sol_2 = sol_2.unwrap(); - assert!(sol_1.complexity < sol_multi.complexity); - if REAL_FAST_KS { - assert!(sol_multi.complexity < sol_2.complexity); - } - eprintln!("{:?}", sol_multi.micro_params.fks); - let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2] - [default_partition] - .unwrap() - .complexity; - let sol_multi_without_fks = sol_multi.complexity - fks_complexity; - let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0; - if REAL_FAST_KS { - assert!(sol_multi.macro_params[1] == sol_2.macro_params[0]); - } - // The smallest the precision the more fks noise break partition independence - #[allow(clippy::collapsible_else_if)] - let maximal_relative_degratdation = if REAL_FAST_KS { - if precision1 < 4 { - 1.1 - } else if precision1 <= 7 { - 1.03 - } else { - 1.001 - } - } else { - if precision1 < 4 { - 1.5 - } else if precision1 <= 7 { - 1.8 - } else { - 1.6 - } - }; - assert!( - sol_multi_without_fks / perfect_complexity < maximal_relative_degratdation, - "{precision1} {precision2} {} < {maximal_relative_degratdation}", - sol_multi_without_fks / perfect_complexity - ); - } - } - } - - #[test] - fn optimize_multi_independant_2_partitions_finally_added_and_luted() { - let default_partition = 0; - let single_precision_sol: Vec<_> = (0..11) - .map(|precision| { - let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true); - optimize_single(&dag) - }) - .collect(); - for precision1 in 
1..11 { - for precision2 in (precision1 + 1)..11 { - let p_cut = Some(PrecisionCut { - p_cut: vec![precision1], - }); - let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, true); - let sol_1 = single_precision_sol[precision1 as usize].clone(); - let sol_2 = single_precision_sol[precision2 as usize].clone(); - let sol_multi = optimize(&dag_multi, &p_cut, 0); - let feasible_multi = sol_multi.is_some(); - let feasible_2 = sol_2.is_some(); - assert!(feasible_multi); - assert!(feasible_2); - let sol_multi = sol_multi.unwrap(); - let sol_1 = sol_1.unwrap(); - let sol_2 = sol_2.unwrap(); - // The smallest the precision the more fks noise dominate - assert!(sol_1.complexity < sol_multi.complexity); - assert!(sol_multi.complexity < sol_2.complexity); - let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2] - [default_partition] - .unwrap() - .complexity; - let sol_multi_without_fks = sol_multi.complexity - fks_complexity; - let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0; - let relative_degradation = sol_multi_without_fks / perfect_complexity; - #[allow(clippy::collapsible_else_if)] - let maxim_relative_degradation = if REAL_FAST_KS { - if precision1 < 4 { - 1.2 - } else if precision1 <= 7 { - 1.19 - } else { - 1.15 - } - } else { - if precision1 < 4 { - 1.45 - } else if precision1 <= 7 { - 1.8 - } else { - 1.6 - } - }; - assert!( - relative_degradation < maxim_relative_degradation, - "{precision1} {precision2} {}", - sol_multi_without_fks / perfect_complexity - ); - } - } - } - - fn optimize_rounded(dag: &unparametrized::OperationDag) -> Option { - let p_cut = Some(PrecisionCut { p_cut: vec![1] }); - let default_partition = 0; - optimize(dag, &p_cut, default_partition) - } - - fn dag_rounded_lut_2_layers( - accumulator_precision: usize, - precision: usize, - ) -> unparametrized::OperationDag { - let out_precision = accumulator_precision as u8; - let rounded_precision = precision as u8; - let mut dag = 
unparametrized::OperationDag::new(); - let input1 = dag.add_input(precision as u8, Shape::number()); - let rounded1 = dag.add_expanded_rounded_lut( - input1, - FunctionTable::UNKWOWN, - rounded_precision, - out_precision, - ); - let rounded2 = dag.add_expanded_rounded_lut( - rounded1, - FunctionTable::UNKWOWN, - rounded_precision, - out_precision, - ); - let _rounded3 = - dag.add_expanded_rounded_lut(rounded2, FunctionTable::UNKWOWN, rounded_precision, 1); - dag - } - - fn test_optimize_v3_expanded_round( - precision_acc: usize, - precision_tlu: usize, - minimal_speedup: f64, - ) { - let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu); - let sol_mono = solo_key::optimize::tests::optimize(&dag) - .best_solution - .unwrap(); - let sol = optimize_rounded(&dag).unwrap(); - let speedup = sol_mono.complexity / sol.complexity; - assert!( - speedup >= minimal_speedup, - "Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}" - ); - let expected_ks = [ - [true, true], // KS[0], KS[0->1] - [false, true], // KS[1] - ]; - let expected_fks = [ - [false, false], - [true, false], // FKS[1->0] - ]; - for (src, dst) in cross_partition(2) { - assert!(sol.micro_params.ks[src][dst].is_some() == expected_ks[src][dst]); - assert!(sol.micro_params.fks[src][dst].is_some() == expected_fks[src][dst]); - } - } - - #[test] - fn test_optimize_v3_expanded_round_16_8() { - if REAL_FAST_KS { - test_optimize_v3_expanded_round(16, 8, 5.5); - } else { - test_optimize_v3_expanded_round(16, 8, 3.9); - } - } - - #[test] - fn test_optimize_v3_expanded_round_16_6() { - if REAL_FAST_KS { - test_optimize_v3_expanded_round(16, 6, 3.3); - } else { - test_optimize_v3_expanded_round(16, 6, 2.6); - } - } - - #[test] - fn optimize_v3_direct_round() { - let mut dag = unparametrized::OperationDag::new(); - let input1 = dag.add_input(16, Shape::number()); - _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16); - let sol = 
optimize_rounded(&dag).unwrap(); - let sol_mono = solo_key::optimize::tests::optimize(&dag) - .best_solution - .unwrap(); - let minimal_speedup = 8.6; - let speedup = sol_mono.complexity / sol.complexity; - assert!( - speedup >= minimal_speedup, - "Speedup {speedup} smaller than {minimal_speedup}" - ); - } - - #[test] - fn optimize_sign_extract() { - let precision = 8; - let high_precision = 16; - let mut dag = unparametrized::OperationDag::new(); - let complexity = LevelledComplexity::ZERO; - let free_small_input1 = dag.add_input(precision, Shape::number()); - let small_input1 = dag.add_lut(free_small_input1, FunctionTable::UNKWOWN, precision); - let small_input1 = dag.add_lut(small_input1, FunctionTable::UNKWOWN, high_precision); - let input1 = dag.add_levelled_op( - [small_input1], - complexity, - 1.0, - Shape::vector(1_000_000), - "comment", - ); - let rounded1 = dag.add_expanded_round(input1, 1); - let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1); - let sol = optimize_rounded(&dag).unwrap(); - let sol_mono = solo_key::optimize::tests::optimize(&dag) - .best_solution - .unwrap(); - let speedup = sol_mono.complexity / sol.complexity; - let minimal_speedup = if REAL_FAST_KS { 80.0 } else { 30.0 }; - assert!( - speedup >= minimal_speedup, - "Speedup {speedup} smaller than {minimal_speedup}" - ); - } - - fn test_partition_chain(decreasing: bool) { - // tlu chain with decreasing precision (decreasing partition index) - // check that increasing partitionning gaves faster solutions - // check solution has the right structure - let mut dag = unparametrized::OperationDag::new(); - let min_precision = 6; - let max_precision = 8; - let mut input_precisions: Vec<_> = (min_precision..=max_precision).collect(); - if decreasing { - input_precisions.reverse(); - } - let mut lut_input = dag.add_input(input_precisions[0], Shape::number()); - for &out_precision in &input_precisions { - lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision); - } 
- lut_input = dag.add_lut( - lut_input, - FunctionTable::UNKWOWN, - *input_precisions.last().unwrap(), - ); - _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision); - let mut p_cut = PrecisionCut { p_cut: vec![] }; - let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); - assert!(sol.macro_params.len() == 1); - let mut complexity = sol.complexity; - for &out_precision in &input_precisions { - if out_precision == max_precision { - // There is nothing to cut above max_precision - continue; - } - p_cut.p_cut.push(out_precision); - p_cut.p_cut.sort_unstable(); - eprintln!("PCUT {p_cut}"); - let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); - let nb_partitions = sol.macro_params.len(); - assert!( - nb_partitions == (p_cut.p_cut.len() + 1), - "bad nb partitions {} {p_cut}", - sol.macro_params.len() - ); - assert!( - sol.complexity < complexity, - "{} < {complexity} {out_precision} / {max_precision}", - sol.complexity - ); - for (src, dst) in cross_partition(nb_partitions) { - let ks = sol.micro_params.ks[src][dst]; - eprintln!("{} {src} {dst}", ks.is_some()); - let expected_ks = (!decreasing || src == dst + 1) && (decreasing || src + 1 == dst) - || (src == dst && (src == 0 || src == nb_partitions - 1)); - assert!( - ks.is_some() == expected_ks, - "{:?} {:?}", - ks.is_some(), - expected_ks - ); - let fks = sol.micro_params.fks[src][dst]; - assert!(fks.is_none()); - } - complexity = sol.complexity; - } - let sol = optimize(&dag, &None, 0); - assert!(sol.unwrap().complexity == complexity); - } - - #[test] - fn test_partition_decreasing_chain() { - test_partition_chain(true); - } - - #[test] - fn test_partition_increasing_chain() { - test_partition_chain(true); - } - - const MAX_WEIGHT: &[u64] = &[ - // max v0 weight for each precision - 1_073_741_824, - 1_073_741_824, // 2**30, 1b - 536_870_912, // 2**29, 2b - 268_435_456, // 2**28, 3b - 67_108_864, // 2**26, 4b - 16_777_216, // 2**24, 5b - 4_194_304, // 2**22, 6b - 1_048_576, // 2**20, 7b - 
262_144, // 2**18, 8b - 65_536, // 2**16, 9b - 16384, // 2**14, 10b - 2048, // 2**11, 11b - ]; - - #[test] - fn test_independant_partitions_non_feasible_single_params() { - // generate hard circuit, non feasible with single parameters - // composed of independant partitions so we know the optimal result - let precisions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; - let sum_size = 0; - let noise_factor = MAX_WEIGHT[precisions[0]] as f64; - let mut dag = v0_dag(sum_size, precisions[0] as u64, noise_factor); - let sol_single = optimize_single(&dag); - let mut optimal_complexity = sol_single.as_ref().unwrap().complexity; - let mut optimal_p_error = sol_single.unwrap().p_error; - for &out_precision in &precisions[1..] { - let noise_factor = MAX_WEIGHT[out_precision] as f64; - add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor); - let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor)); - optimal_complexity += sol_single.as_ref().unwrap().complexity; - optimal_p_error += sol_single.as_ref().unwrap().p_error; - } - // check non feasible in single - let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution; - assert!(sol_single.is_none()); - // solves in multi - let sol = optimize(&dag, &None, 0); - assert!(sol.is_some()); - let sol = sol.unwrap(); - // check optimality - assert!(sol.complexity / optimal_complexity < 1.0 + f64::EPSILON); - assert!(sol.p_error / optimal_p_error < 1.0 + f64::EPSILON); - } - - #[test] - fn test_chained_partitions_non_feasible_single_params() { - // generate hard circuit, non feasible with single parameters - let precisions = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; - // Note: reversing chain have issues for connecting lower bits to 7 bits, there may be no feasible solution - let mut dag = unparametrized::OperationDag::new(); - let mut lut_input = dag.add_input(precisions[0], Shape::number()); - for out_precision in precisions { - let noise_factor = 
MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64; - lut_input = dag.add_levelled_op( - [lut_input], - LevelledComplexity::ZERO, - noise_factor, - Shape::number(), - "", - ); - lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision); - } - _ = dag.add_lut( - lut_input, - FunctionTable::UNKWOWN, - *precisions.last().unwrap(), - ); - let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution; - assert!(sol_single.is_none()); - let sol = optimize(&dag, &None, 0); - assert!(sol.is_some()); - } - - #[test] - fn test_multi_rounded_fks_coherency() { - let mut dag = unparametrized::OperationDag::new(); - let input1 = dag.add_input(16, Shape::number()); - let reduced_8 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 8); - let reduced_4 = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); - _ = dag.add_dot([reduced_8, reduced_4], [1, 1]); - let sol = optimize(&dag, &None, 0); - assert!(sol.is_some()); - let sol = sol.unwrap(); - for (src, dst) in cross_partition(sol.macro_params.len()) { - if let Some(fks) = sol.micro_params.fks[src][dst] { - assert!(fks.src_glwe_param == sol.macro_params[src].unwrap().glwe_params); - assert!(fks.dst_glwe_param == sol.macro_params[dst].unwrap().glwe_params); - } - } - } - - #[test] - fn test_levelled_only() { - let mut dag = unparametrized::OperationDag::new(); - let _ = dag.add_input(22, Shape::number()); - let config = Config { - security_level: 128, - maximum_acceptable_error_probability: _4_SIGMA, - key_sharing: true, - ciphertext_modulus_log: 64, - fft_precision: 53, - complexity_model: &CpuComplexity::default(), - }; - - let search_space = SearchSpace::default_cpu(); - let sol = super::optimize_to_circuit_solution( - &dag, - config, - &search_space, - &SHARED_CACHES, - &None, - ); - let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap(); - assert! (sol.circuit_keys.secret_keys.len() == 1); - assert! 
(sol.circuit_keys.secret_keys[0].polynomial_size == sol_mono.glwe_polynomial_size); - } - - #[test] - fn test_big_secret_key_sharing() { - let mut dag = unparametrized::OperationDag::new(); - let input1 = dag.add_input(4, Shape::number()); - let input2 = dag.add_input(5, Shape::number()); - let input2 = dag.add_dot([input2], [128]); - let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); - let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); - let _ = dag.add_dot([lut1, lut2], [16, 1]); - let config_sharing = Config { - security_level: 128, - maximum_acceptable_error_probability: _4_SIGMA, - key_sharing: true, - ciphertext_modulus_log: 64, - fft_precision: 53, - complexity_model: &CpuComplexity::default(), - }; - let config_no_sharing = Config { - key_sharing: false, - ..config_sharing - }; - let mut search_space = SearchSpace::default_cpu(); - // eprintln!("{:?}", search_space); - search_space.glwe_dimensions = vec![1]; // forcing big key sharing - let sol_sharing = super::optimize_to_circuit_solution( - &dag, - config_sharing, - &search_space, - &SHARED_CACHES, - &None, - ); - eprintln!("NO SHARING"); - let sol_no_sharing = super::optimize_to_circuit_solution( - &dag, - config_no_sharing, - &search_space, - &SHARED_CACHES, - &None, - ); - let keys_sharing = sol_sharing.circuit_keys; - let keys_no_sharing = sol_no_sharing.circuit_keys; - assert!(keys_sharing.secret_keys.len() == 3); - assert!(keys_no_sharing.secret_keys.len() == 4); - assert!(keys_sharing.conversion_keyswitch_keys.is_empty()); - assert!(keys_no_sharing.conversion_keyswitch_keys.len() == 1); - assert!(keys_sharing.bootstrap_keys.len() == keys_no_sharing.bootstrap_keys.len()); - assert!(keys_sharing.keyswitch_keys.len() == keys_no_sharing.keyswitch_keys.len()); - } - - #[test] - fn test_big_and_small_secret_key() { - let mut dag = unparametrized::OperationDag::new(); - let input1 = dag.add_input(4, Shape::number()); - let input2 = dag.add_input(5, Shape::number()); - let input2 = 
dag.add_dot([input2], [128]); - let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 5); - let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 5); - let _ = dag.add_dot([lut1, lut2], [16, 1]); - let config_sharing = Config { - security_level: 128, - maximum_acceptable_error_probability: _4_SIGMA, - key_sharing: true, - ciphertext_modulus_log: 64, - fft_precision: 53, - complexity_model: &CpuComplexity::default(), - }; - let config_no_sharing = Config { - key_sharing: false, - ..config_sharing - }; - let mut search_space = SearchSpace::default_cpu(); - search_space.glwe_dimensions = vec![1]; // forcing big key sharing - search_space.internal_lwe_dimensions = vec![768]; // forcing small key sharing - let sol_sharing = super::optimize_to_circuit_solution( - &dag, - config_sharing, - &search_space, - &SHARED_CACHES, - &None, - ); - let sol_no_sharing = super::optimize_to_circuit_solution( - &dag, - config_no_sharing, - &search_space, - &SHARED_CACHES, - &None, - ); - let keys_sharing = sol_sharing.circuit_keys; - let keys_no_sharing = sol_no_sharing.circuit_keys; - assert!(keys_sharing.secret_keys.len() == 2); - assert!(keys_no_sharing.secret_keys.len() == 4); - assert!(keys_sharing.conversion_keyswitch_keys.is_empty()); - assert!(keys_no_sharing.conversion_keyswitch_keys.len() == 1); - // boostrap are merged due to same (level, base) - assert!(keys_sharing.bootstrap_keys.len() + 1 == keys_no_sharing.bootstrap_keys.len()); - // keyswitch are still different due to another (level, base) - assert!(keys_sharing.keyswitch_keys.len() == keys_no_sharing.keyswitch_keys.len()); - } -} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 945018467..67e1322f4 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ 
b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -480,6 +480,7 @@ pub(crate) mod tests { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), + composable: false, }; let search_space = SearchSpace::default_cpu(); @@ -525,6 +526,7 @@ pub(crate) mod tests { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), + composable: false, }; _ = optimize_v0( @@ -623,6 +625,7 @@ pub(crate) mod tests { ciphertext_modulus_log: 64, fft_precision: 53, complexity_model: &CpuComplexity::default(), + composable: false, }; let state = optimize(&dag); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/mod.rs index e66098c05..2e608694d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/mod.rs @@ -3,3 +3,20 @@ pub mod config; pub mod dag; pub mod decomposition; pub mod wop_atomic_pattern; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Err { + NotComposable, + NoParametersFound, +} + +impl std::fmt::Display for Err { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::NotComposable => write!(f, "NotComposable"), + Self::NoParametersFound => write!(f, "NoParametersFound"), + } + } +} + +type Result = std::result::Result; diff --git a/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs b/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs index 5ce7a8e11..7c5026144 100644 --- a/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs +++ b/compilers/concrete-optimizer/v0-parameters/benches/benchmark.rs @@ -20,6 +20,7 @@ fn v0_pbs_optimization(c: &mut Criterion) { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, + composable: false, }; c.bench_function("v0 PBS table 
generation", |b| { @@ -46,6 +47,7 @@ fn v0_pbs_optimization_simulate_graph(c: &mut Criterion) { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, + composable: false, }; c.bench_function("v0 PBS simulate dag table generation", |b| { @@ -72,6 +74,7 @@ fn v0_wop_pbs_optimization(c: &mut Criterion) { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, + composable: false, }; c.bench_function("v0 WoP-PBS table generation", |b| { diff --git a/compilers/concrete-optimizer/v0-parameters/src/lib.rs b/compilers/concrete-optimizer/v0-parameters/src/lib.rs index e3321efd5..4a2b39090 100644 --- a/compilers/concrete-optimizer/v0-parameters/src/lib.rs +++ b/compilers/concrete-optimizer/v0-parameters/src/lib.rs @@ -92,6 +92,9 @@ pub struct Args { #[clap(long, default_value_t = 53)] pub fft_precision: u32, + + #[clap(long)] + pub composable: bool, } pub fn all_results(args: &Args) -> Vec>> { @@ -100,6 +103,7 @@ pub fn all_results(args: &Args) -> Vec>> { let maximum_acceptable_error_probability = args.p_error; let security_level = args.security_level; let cache_on_disk = args.cache_on_disk; + let composable = args.composable; let search_space = SearchSpace { glwe_log_polynomial_sizes: (args.min_log_poly_size..=args.max_log_poly_size).collect(), @@ -122,6 +126,7 @@ pub fn all_results(args: &Args) -> Vec>> { ciphertext_modulus_log: args.ciphertext_modulus_log, fft_precision: args.fft_precision, complexity_model: &CpuComplexity::default(), + composable, }; let cache = decomposition::cache( @@ -295,6 +300,7 @@ mod tests { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, + composable: false, }; let mut actual_output = Vec::::new(); @@ -339,6 +345,7 @@ mod tests { cache_on_disk: true, ciphertext_modulus_log: 64, fft_precision: 53, + composable: false, }; let mut actual_output = Vec::::new(); diff --git a/docs/howto/configure.md b/docs/howto/configure.md index 0d621d3bf..2a0331f3a 100644 --- a/docs/howto/configure.md +++ 
b/docs/howto/configure.md @@ -74,8 +74,6 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * Use single precision for the whole circuit. * **parameter\_selection\_strategy**: (fhe.ParameterSelectionStrategy) = fhe.ParameterSelectionStrategy.MULTI * Set how cryptographic parameters are selected. -* **jit**: bool = False - * Enable JIT compilation. * **loop\_parallelize**: bool = True * Enable loop parallelization in the compiler. * **dataflow\_parallelize**: bool = False @@ -108,3 +106,5 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * Specify preference for bitwise strategies, can be a single strategy or an ordered list of strategies. See [Bitwise](../tutorial/bitwise.md) to learn more. * **shifts_with_promotion**: bool = True, * Enable promotions in encrypted shifts instead of casting in runtime. See [Bitwise#Shifts](../tutorial/bitwise.md#Shifts) to learn more. +* **composable**: bool = False, + * Specify that the function must be composable with itself. 
diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index dc7dafffa..752be4f9f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -899,6 +899,7 @@ class Configuration: shifts_with_promotion: bool multivariate_strategy_preference: List[MultivariateStrategy] min_max_strategy_preference: List[MinMaxStrategy] + composable: bool def __init__( self, @@ -947,6 +948,7 @@ class Configuration: min_max_strategy_preference: Optional[ Union[MinMaxStrategy, str, List[Union[MinMaxStrategy, str]]] ] = None, + composable: bool = False, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1023,6 +1025,7 @@ class Configuration: else [MinMaxStrategy.parse(min_max_strategy_preference)] ) ) + self.composable = composable self._validate() @@ -1077,6 +1080,7 @@ class Configuration: min_max_strategy_preference: Union[ Keep, Optional[Union[MinMaxStrategy, str, List[Union[MinMaxStrategy, str]]]] ] = KEEP, + composable: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. 
@@ -1136,6 +1140,13 @@ class Configuration: message = "Dataflow parallelism is not available in macOS" raise RuntimeError(message) + if ( + self.composable + and self.parameter_selection_strategy != ParameterSelectionStrategy.MULTI + ): # pragma: no cover + message = "Composition can only be used with MULTI parameter selection strategy" + raise RuntimeError(message) + def __check_fork_consistency(): hints_init = get_type_hints(Configuration.__init__) diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index bc9fd5eba..e6e523d9a 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -113,6 +113,7 @@ class Server: options.set_dataflow_parallelize(configuration.dataflow_parallelize) options.set_auto_parallelize(configuration.auto_parallelize) options.set_compress_inputs(configuration.compress_inputs) + options.set_composable(configuration.composable) if configuration.auto_parallelize or configuration.dataflow_parallelize: # pylint: disable=c-extension-no-member,no-member diff --git a/frontends/concrete-python/tests/conftest.py b/frontends/concrete-python/tests/conftest.py index 60629ec8d..ab0742d17 100644 --- a/frontends/concrete-python/tests/conftest.py +++ b/frontends/concrete-python/tests/conftest.py @@ -345,6 +345,75 @@ Actual Output During Simulation """ raise AssertionError(message) + @staticmethod + def check_composition( + circuit: fhe.Circuit, function: Callable, sample: Union[Any, List[Any]], composed: int + ): + """ + Assert that `circuit` behaves the same as `function` on `sample` when composed. 
+ + Args: + circuit (fhe.Circuit): + compiled circuit + + function (Callable): + original function + + sample (Union[Any, List[Any]]): + inputs + + composed (int): + number of times to compose the function (call sequentially with inputs as outputs) + """ + + if not isinstance(sample, list): + sample = [sample] + + def sanitize(values): + if not isinstance(values, tuple): + values = (values,) + + result = [] + for value in values: + if isinstance(value, (bool, np.bool_)): + value = int(value) + elif isinstance(value, np.ndarray) and value.dtype == np.bool_: + value = value.astype(np.int64) + + result.append(value) + + return tuple(result) + + def compute_expected(sample): + for _ in range(composed): + sample = function(*sample) + if not isinstance(sample, tuple): + sample = (sample,) + return sample + + def compute_actual(sample): + inp = circuit.encrypt(*sample) + for _ in range(composed): + inp = circuit.run(inp) + out = circuit.decrypt(inp) + return out + + expected = sanitize(compute_expected(sample)) + actual = sanitize(compute_actual(sample)) + + if not all(np.array_equal(e, a) for e, a in zip(expected, actual)): + message = f""" + + Expected Output + =============== + {expected} + + Actual Output + ============= + {actual} + """ + raise AssertionError(message) + @staticmethod + def check_str(expected: str, actual: str): """ diff --git a/frontends/concrete-python/tests/execution/test_composition.py b/frontends/concrete-python/tests/execution/test_composition.py new file mode 100644 index 000000000..38c39050b --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_composition.py @@ -0,0 +1,39 @@ +""" +Tests of execution of function composition. +""" + +import numpy as np + +from concrete import fhe + + +def test_composed_inc(helpers): + """ + Test composing a lookup-table increment function with itself.
+ """ + + if helpers.configuration().parameter_selection_strategy != fhe.ParameterSelectionStrategy.MULTI: + # Only valid with multi + return + + lut = fhe.LookupTable(list(range(32))) + + @fhe.compiler({"x": "encrypted"}) + def function(x): + return lut[x + 1] + + inputset = range(30) + circuit = function.compile(inputset, helpers.configuration()) + + samples = [ + [ + np.random.randint( + 0, + 31 - 6, + ) + ] + for _ in range(5) + ] + + for sample in samples: + helpers.check_composition(circuit, function, sample, composed=6)