diff --git a/concrete-optimizer-cpp/src/concrete-optimizer.rs b/concrete-optimizer-cpp/src/concrete-optimizer.rs index 477dcefd0..e89bd8714 100644 --- a/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -247,7 +247,7 @@ impl OperationDag { pub struct Weights(operator::Weights); -fn vector(weights: &[u64]) -> Box { +fn vector(weights: &[i64]) -> Box { Box::new(Weights(operator::Weights::vector(weights))) } @@ -320,7 +320,7 @@ mod ffi { type Weights; #[namespace = "concrete_optimizer::weights"] - fn vector(weights: &[u64]) -> Box; + fn vector(weights: &[i64]) -> Box; } #[derive(Clone, Copy)] diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 67987f2c6..ce19f8e44 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -1099,7 +1099,7 @@ void concrete_optimizer$cxxbridge1$OperationDag$dump(const ::concrete_optimizer: namespace weights { extern "C" { -::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice weights) noexcept; +::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice weights) noexcept; } // extern "C" } // namespace weights @@ -1172,7 +1172,7 @@ namespace dag { } namespace weights { -::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice weights) noexcept { +::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice weights) noexcept { return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$vector(weights)); } } // namespace weights diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 1d83a26fa..8a77ca946 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -1050,6 +1050,6 @@ namespace dag { } // namespace dag namespace weights { -::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice weights) noexcept; +::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice weights) noexcept; } // namespace weights } // namespace concrete_optimizer diff --git a/concrete-optimizer-cpp/tests/src/main.cpp b/concrete-optimizer-cpp/tests/src/main.cpp index d5b5eae84..7cda066bc 100644 --- a/concrete-optimizer-cpp/tests/src/main.cpp +++ b/concrete-optimizer-cpp/tests/src/main.cpp @@ -42,7 +42,7 @@ void test_dag_no_lut() { std::vector inputs = {node1}; - std::vector weight_vec = {1, 1, 1}; + std::vector weight_vec = {1, 1, 1}; rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); diff --git a/concrete-optimizer/src/dag/operator/dot_kind.rs b/concrete-optimizer/src/dag/operator/dot_kind.rs index 1dc47ac6a..55eb89232 100644 --- a/concrete-optimizer/src/dag/operator/dot_kind.rs +++ b/concrete-optimizer/src/dag/operator/dot_kind.rs @@ -13,7 +13,7 @@ pub enum DotKind { Unsupported, } -pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor) -> DotKind { +pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor) -> DotKind { let inputs_shape = Shape::duplicated(nb_inputs, input_shape); if input_shape.is_number() && inputs_shape == weights.shape { DotKind::Simple diff --git a/concrete-optimizer/src/dag/operator/operator.rs b/concrete-optimizer/src/dag/operator/operator.rs index 72014d63d..d08cdefab 100644 --- a/concrete-optimizer/src/dag/operator/operator.rs +++ b/concrete-optimizer/src/dag/operator/operator.rs @@ -1,7 +1,7 @@ use crate::dag::operator::tensor::{ClearTensor, Shape}; use derive_more::{Add, AddAssign}; -pub type Weights = ClearTensor; +pub type Weights = ClearTensor; #[derive(Clone, PartialEq, Eq, Debug)] pub struct FunctionTable { diff --git a/concrete-optimizer/src/dag/operator/tensor.rs b/concrete-optimizer/src/dag/operator/tensor.rs index 477b9e13e..a09533106 100644 --- a/concrete-optimizer/src/dag/operator/tensor.rs +++ b/concrete-optimizer/src/dag/operator/tensor.rs @@ -1,3 +1,5 @@ +use std::{iter::Sum, ops::Mul}; + use delegate::delegate; use crate::utils::square_ref; @@ -59,20 +61,23 @@ impl Shape { } #[derive(Clone, PartialEq, Eq, Debug)] -pub struct ClearTensor { +pub struct ClearTensor { pub shape: Shape, - pub values: Vec, + pub values: Vec, } -impl ClearTensor { - pub fn number(value: u64) -> Self { +impl ClearTensor +where + W: Copy + Mul + Sum, +{ + pub fn number(value: W) -> Self { Self { shape: Shape::number(), values: vec![value], } } - pub fn vector(values: impl Into>) -> Self { + pub fn vector(values: impl Into>) -> Self { let values = values.into(); Self { shape: Shape::vector(values.len() as u64), @@ -89,7 +94,7 @@ impl ClearTensor { } } - pub fn square_norm2(&self) -> u64 { + pub fn square_norm2(&self) -> W { self.values.iter().map(square_ref).sum() } } @@ -102,15 +107,21 @@ impl From<&Self> for Shape { } // helps using shared weights -impl From<&Self> for ClearTensor { +impl From<&Self> for ClearTensor +where + W: Copy + Mul + Sum, +{ fn from(item: &Self) -> Self { item.clone() } } // helps using array as weights -impl From<[u64; N]> for ClearTensor { - fn from(item: [u64; N]) -> Self { - Self::vector(item) +impl From<[W; N]> for ClearTensor +where + W: Copy + Mul + Sum, +{ + fn from(items: [W; N]) -> Self { + Self::vector(items) } } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index ad38e61e0..bdf63f1a9 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -361,9 +361,10 @@ mod tests { use crate::computing_cost::cpu::CpuComplexity; use crate::dag::operator::{FunctionTable, Shape, Weights}; use crate::noise_estimator::p_error::repeat_p_error; + use crate::optimization::atomic_pattern; use crate::optimization::config::SearchSpace; use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin; - use crate::optimization::{atomic_pattern, decomposition}; + use crate::optimization::decomposition; use crate::utils::square; fn small_relative_diff(v1: f64, v2: f64) -> bool { @@ -497,7 +498,7 @@ mod tests { } } - fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) { + fn v0_parameter_ref_with_dot(precision: Precision, weight: i64) { let security_level = 128; let cache = decomposition::cache(security_level); @@ -596,7 +597,7 @@ mod tests { fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise( precision: Precision, - weight: u64, + weight: i64, cache: &PersistDecompCache, ) { let weight = &Weights::number(weight); @@ -673,7 +674,7 @@ mod tests { } } - fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) { + fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: i64) { let input = dag.add_input(precision, Shape::number()); let dot1 = dag.add_dot([input], [weight]); let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); @@ -682,7 +683,7 @@ mod tests { } fn assert_multi_precision_dominate_single( - weight: u64, + weight: i64, cache: &PersistDecompCache, ) -> Option { let low_precision = 4u8; @@ -767,7 +768,7 @@ mod tests { fn check_global_p_error_input( dim: u64, - weight: u64, + weight: i64, precision: u8, cache: &PersistDecompCache, ) -> f64 { @@ -797,7 +798,7 @@ mod tests { fn check_global_p_error_lut( depth: u64, - weight: u64, + weight: i64, precision: u8, cache: &PersistDecompCache, ) { @@ -825,8 +826,8 @@ mod tests { depth: u64, precision_low: Precision, precision_high: Precision, - weight_low: u64, - weight_high: u64, + weight_low: i64, + weight_high: i64, ) -> unparametrized::OperationDag { let shape = Shape::number(); let mut dag = unparametrized::OperationDag::new(); diff --git a/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs b/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs index b0dea5545..1f5a83a85 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs @@ -48,6 +48,13 @@ impl std::ops::Mul for SymbolicVariance { } } +impl std::ops::Mul for SymbolicVariance { + type Output = Self; + fn mul(self, sq_weight: i64) -> Self { + self * sq_weight as f64 + } +} + impl SymbolicVariance { pub const ZERO: Self = Self { input_coeff: 0.0,