feat: dot weights are signed

This commit is contained in:
rudy
2022-09-13 16:07:05 +02:00
committed by rudy-6-4
parent 3b9535ac2f
commit 1f15162b39
9 changed files with 46 additions and 27 deletions

View File

@@ -247,7 +247,7 @@ impl OperationDag {
pub struct Weights(operator::Weights);
fn vector(weights: &[u64]) -> Box<Weights> {
fn vector(weights: &[i64]) -> Box<Weights> {
Box::new(Weights(operator::Weights::vector(weights)))
}
@@ -320,7 +320,7 @@ mod ffi {
type Weights;
#[namespace = "concrete_optimizer::weights"]
fn vector(weights: &[u64]) -> Box<Weights>;
fn vector(weights: &[i64]) -> Box<Weights>;
}
#[derive(Clone, Copy)]

View File

@@ -1099,7 +1099,7 @@ void concrete_optimizer$cxxbridge1$OperationDag$dump(const ::concrete_optimizer:
namespace weights {
extern "C" {
::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice<const ::std::uint64_t> weights) noexcept;
::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice<const ::std::int64_t> weights) noexcept;
} // extern "C"
} // namespace weights
@@ -1172,7 +1172,7 @@ namespace dag {
}
namespace weights {
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::uint64_t> weights) noexcept {
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::int64_t> weights) noexcept {
return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$vector(weights));
}
} // namespace weights

View File

@@ -1050,6 +1050,6 @@ namespace dag {
} // namespace dag
namespace weights {
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::uint64_t> weights) noexcept;
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::int64_t> weights) noexcept;
} // namespace weights
} // namespace concrete_optimizer

View File

@@ -42,7 +42,7 @@ void test_dag_no_lut() {
std::vector<concrete_optimizer::dag::OperatorIndex> inputs = {node1};
std::vector<uint64_t> weight_vec = {1, 1, 1};
std::vector<int64_t> weight_vec = {1, 1, 1};
rust::cxxbridge1::Box<concrete_optimizer::Weights> weights =
concrete_optimizer::weights::vector(slice(weight_vec));

View File

@@ -13,7 +13,7 @@ pub enum DotKind {
Unsupported,
}
pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor) -> DotKind {
pub fn dot_kind<W>(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor<W>) -> DotKind {
let inputs_shape = Shape::duplicated(nb_inputs, input_shape);
if input_shape.is_number() && inputs_shape == weights.shape {
DotKind::Simple

View File

@@ -1,7 +1,7 @@
use crate::dag::operator::tensor::{ClearTensor, Shape};
use derive_more::{Add, AddAssign};
pub type Weights = ClearTensor;
pub type Weights = ClearTensor<i64>;
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct FunctionTable {

View File

@@ -1,3 +1,5 @@
use std::{iter::Sum, ops::Mul};
use delegate::delegate;
use crate::utils::square_ref;
@@ -59,20 +61,23 @@ impl Shape {
}
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct ClearTensor {
pub struct ClearTensor<W> {
pub shape: Shape,
pub values: Vec<u64>,
pub values: Vec<W>,
}
impl ClearTensor {
pub fn number(value: u64) -> Self {
impl<W> ClearTensor<W>
where
W: Copy + Mul<Output = W> + Sum<W>,
{
pub fn number(value: W) -> Self {
Self {
shape: Shape::number(),
values: vec![value],
}
}
pub fn vector(values: impl Into<Vec<u64>>) -> Self {
pub fn vector(values: impl Into<Vec<W>>) -> Self {
let values = values.into();
Self {
shape: Shape::vector(values.len() as u64),
@@ -89,7 +94,7 @@ impl ClearTensor {
}
}
pub fn square_norm2(&self) -> u64 {
pub fn square_norm2(&self) -> W {
self.values.iter().map(square_ref).sum()
}
}
@@ -102,15 +107,21 @@ impl From<&Self> for Shape {
}
// helps using shared weights
impl From<&Self> for ClearTensor {
impl<W> From<&Self> for ClearTensor<W>
where
W: Copy + Mul<Output = W> + Sum<W>,
{
fn from(item: &Self) -> Self {
item.clone()
}
}
// helps using array as weights
impl<const N: usize> From<[u64; N]> for ClearTensor {
fn from(item: [u64; N]) -> Self {
Self::vector(item)
impl<const N: usize, W> From<[W; N]> for ClearTensor<W>
where
W: Copy + Mul<Output = W> + Sum<W>,
{
fn from(items: [W; N]) -> Self {
Self::vector(items)
}
}

View File

@@ -361,9 +361,10 @@ mod tests {
use crate::computing_cost::cpu::CpuComplexity;
use crate::dag::operator::{FunctionTable, Shape, Weights};
use crate::noise_estimator::p_error::repeat_p_error;
use crate::optimization::atomic_pattern;
use crate::optimization::config::SearchSpace;
use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin;
use crate::optimization::{atomic_pattern, decomposition};
use crate::optimization::decomposition;
use crate::utils::square;
fn small_relative_diff(v1: f64, v2: f64) -> bool {
@@ -497,7 +498,7 @@ mod tests {
}
}
fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) {
fn v0_parameter_ref_with_dot(precision: Precision, weight: i64) {
let security_level = 128;
let cache = decomposition::cache(security_level);
@@ -596,7 +597,7 @@ mod tests {
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
precision: Precision,
weight: u64,
weight: i64,
cache: &PersistDecompCache,
) {
let weight = &Weights::number(weight);
@@ -673,7 +674,7 @@ mod tests {
}
}
fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) {
fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: i64) {
let input = dag.add_input(precision, Shape::number());
let dot1 = dag.add_dot([input], [weight]);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
@@ -682,7 +683,7 @@ mod tests {
}
fn assert_multi_precision_dominate_single(
weight: u64,
weight: i64,
cache: &PersistDecompCache,
) -> Option<bool> {
let low_precision = 4u8;
@@ -767,7 +768,7 @@ mod tests {
fn check_global_p_error_input(
dim: u64,
weight: u64,
weight: i64,
precision: u8,
cache: &PersistDecompCache,
) -> f64 {
@@ -797,7 +798,7 @@ mod tests {
fn check_global_p_error_lut(
depth: u64,
weight: u64,
weight: i64,
precision: u8,
cache: &PersistDecompCache,
) {
@@ -825,8 +826,8 @@ mod tests {
depth: u64,
precision_low: Precision,
precision_high: Precision,
weight_low: u64,
weight_high: u64,
weight_low: i64,
weight_high: i64,
) -> unparametrized::OperationDag {
let shape = Shape::number();
let mut dag = unparametrized::OperationDag::new();

View File

@@ -48,6 +48,13 @@ impl std::ops::Mul<u64> for SymbolicVariance {
}
}
impl std::ops::Mul<i64> for SymbolicVariance {
type Output = Self;
fn mul(self, sq_weight: i64) -> Self {
self * sq_weight as f64
}
}
impl SymbolicVariance {
pub const ZERO: Self = Self {
input_coeff: 0.0,