mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: dot weights are signed
This commit is contained in:
@@ -247,7 +247,7 @@ impl OperationDag {
|
||||
|
||||
pub struct Weights(operator::Weights);
|
||||
|
||||
fn vector(weights: &[u64]) -> Box<Weights> {
|
||||
fn vector(weights: &[i64]) -> Box<Weights> {
|
||||
Box::new(Weights(operator::Weights::vector(weights)))
|
||||
}
|
||||
|
||||
@@ -320,7 +320,7 @@ mod ffi {
|
||||
type Weights;
|
||||
|
||||
#[namespace = "concrete_optimizer::weights"]
|
||||
fn vector(weights: &[u64]) -> Box<Weights>;
|
||||
fn vector(weights: &[i64]) -> Box<Weights>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
|
||||
@@ -1099,7 +1099,7 @@ void concrete_optimizer$cxxbridge1$OperationDag$dump(const ::concrete_optimizer:
|
||||
|
||||
namespace weights {
|
||||
extern "C" {
|
||||
::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice<const ::std::uint64_t> weights) noexcept;
|
||||
::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice<const ::std::int64_t> weights) noexcept;
|
||||
} // extern "C"
|
||||
} // namespace weights
|
||||
|
||||
@@ -1172,7 +1172,7 @@ namespace dag {
|
||||
}
|
||||
|
||||
namespace weights {
|
||||
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::uint64_t> weights) noexcept {
|
||||
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::int64_t> weights) noexcept {
|
||||
return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$vector(weights));
|
||||
}
|
||||
} // namespace weights
|
||||
|
||||
@@ -1050,6 +1050,6 @@ namespace dag {
|
||||
} // namespace dag
|
||||
|
||||
namespace weights {
|
||||
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::uint64_t> weights) noexcept;
|
||||
::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<const ::std::int64_t> weights) noexcept;
|
||||
} // namespace weights
|
||||
} // namespace concrete_optimizer
|
||||
|
||||
@@ -42,7 +42,7 @@ void test_dag_no_lut() {
|
||||
|
||||
std::vector<concrete_optimizer::dag::OperatorIndex> inputs = {node1};
|
||||
|
||||
std::vector<uint64_t> weight_vec = {1, 1, 1};
|
||||
std::vector<int64_t> weight_vec = {1, 1, 1};
|
||||
|
||||
rust::cxxbridge1::Box<concrete_optimizer::Weights> weights =
|
||||
concrete_optimizer::weights::vector(slice(weight_vec));
|
||||
|
||||
@@ -13,7 +13,7 @@ pub enum DotKind {
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor) -> DotKind {
|
||||
pub fn dot_kind<W>(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor<W>) -> DotKind {
|
||||
let inputs_shape = Shape::duplicated(nb_inputs, input_shape);
|
||||
if input_shape.is_number() && inputs_shape == weights.shape {
|
||||
DotKind::Simple
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::dag::operator::tensor::{ClearTensor, Shape};
|
||||
use derive_more::{Add, AddAssign};
|
||||
|
||||
pub type Weights = ClearTensor;
|
||||
pub type Weights = ClearTensor<i64>;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub struct FunctionTable {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::{iter::Sum, ops::Mul};
|
||||
|
||||
use delegate::delegate;
|
||||
|
||||
use crate::utils::square_ref;
|
||||
@@ -59,20 +61,23 @@ impl Shape {
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub struct ClearTensor {
|
||||
pub struct ClearTensor<W> {
|
||||
pub shape: Shape,
|
||||
pub values: Vec<u64>,
|
||||
pub values: Vec<W>,
|
||||
}
|
||||
|
||||
impl ClearTensor {
|
||||
pub fn number(value: u64) -> Self {
|
||||
impl<W> ClearTensor<W>
|
||||
where
|
||||
W: Copy + Mul<Output = W> + Sum<W>,
|
||||
{
|
||||
pub fn number(value: W) -> Self {
|
||||
Self {
|
||||
shape: Shape::number(),
|
||||
values: vec![value],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector(values: impl Into<Vec<u64>>) -> Self {
|
||||
pub fn vector(values: impl Into<Vec<W>>) -> Self {
|
||||
let values = values.into();
|
||||
Self {
|
||||
shape: Shape::vector(values.len() as u64),
|
||||
@@ -89,7 +94,7 @@ impl ClearTensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn square_norm2(&self) -> u64 {
|
||||
pub fn square_norm2(&self) -> W {
|
||||
self.values.iter().map(square_ref).sum()
|
||||
}
|
||||
}
|
||||
@@ -102,15 +107,21 @@ impl From<&Self> for Shape {
|
||||
}
|
||||
|
||||
// helps using shared weights
|
||||
impl From<&Self> for ClearTensor {
|
||||
impl<W> From<&Self> for ClearTensor<W>
|
||||
where
|
||||
W: Copy + Mul<Output = W> + Sum<W>,
|
||||
{
|
||||
fn from(item: &Self) -> Self {
|
||||
item.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// helps using array as weights
|
||||
impl<const N: usize> From<[u64; N]> for ClearTensor {
|
||||
fn from(item: [u64; N]) -> Self {
|
||||
Self::vector(item)
|
||||
impl<const N: usize, W> From<[W; N]> for ClearTensor<W>
|
||||
where
|
||||
W: Copy + Mul<Output = W> + Sum<W>,
|
||||
{
|
||||
fn from(items: [W; N]) -> Self {
|
||||
Self::vector(items)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -361,9 +361,10 @@ mod tests {
|
||||
use crate::computing_cost::cpu::CpuComplexity;
|
||||
use crate::dag::operator::{FunctionTable, Shape, Weights};
|
||||
use crate::noise_estimator::p_error::repeat_p_error;
|
||||
use crate::optimization::atomic_pattern;
|
||||
use crate::optimization::config::SearchSpace;
|
||||
use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin;
|
||||
use crate::optimization::{atomic_pattern, decomposition};
|
||||
use crate::optimization::decomposition;
|
||||
use crate::utils::square;
|
||||
|
||||
fn small_relative_diff(v1: f64, v2: f64) -> bool {
|
||||
@@ -497,7 +498,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) {
|
||||
fn v0_parameter_ref_with_dot(precision: Precision, weight: i64) {
|
||||
let security_level = 128;
|
||||
|
||||
let cache = decomposition::cache(security_level);
|
||||
@@ -596,7 +597,7 @@ mod tests {
|
||||
|
||||
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
|
||||
precision: Precision,
|
||||
weight: u64,
|
||||
weight: i64,
|
||||
cache: &PersistDecompCache,
|
||||
) {
|
||||
let weight = &Weights::number(weight);
|
||||
@@ -673,7 +674,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) {
|
||||
fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: i64) {
|
||||
let input = dag.add_input(precision, Shape::number());
|
||||
let dot1 = dag.add_dot([input], [weight]);
|
||||
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
|
||||
@@ -682,7 +683,7 @@ mod tests {
|
||||
}
|
||||
|
||||
fn assert_multi_precision_dominate_single(
|
||||
weight: u64,
|
||||
weight: i64,
|
||||
cache: &PersistDecompCache,
|
||||
) -> Option<bool> {
|
||||
let low_precision = 4u8;
|
||||
@@ -767,7 +768,7 @@ mod tests {
|
||||
|
||||
fn check_global_p_error_input(
|
||||
dim: u64,
|
||||
weight: u64,
|
||||
weight: i64,
|
||||
precision: u8,
|
||||
cache: &PersistDecompCache,
|
||||
) -> f64 {
|
||||
@@ -797,7 +798,7 @@ mod tests {
|
||||
|
||||
fn check_global_p_error_lut(
|
||||
depth: u64,
|
||||
weight: u64,
|
||||
weight: i64,
|
||||
precision: u8,
|
||||
cache: &PersistDecompCache,
|
||||
) {
|
||||
@@ -825,8 +826,8 @@ mod tests {
|
||||
depth: u64,
|
||||
precision_low: Precision,
|
||||
precision_high: Precision,
|
||||
weight_low: u64,
|
||||
weight_high: u64,
|
||||
weight_low: i64,
|
||||
weight_high: i64,
|
||||
) -> unparametrized::OperationDag {
|
||||
let shape = Shape::number();
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
|
||||
@@ -48,6 +48,13 @@ impl std::ops::Mul<u64> for SymbolicVariance {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Mul<i64> for SymbolicVariance {
|
||||
type Output = Self;
|
||||
fn mul(self, sq_weight: i64) -> Self {
|
||||
self * sq_weight as f64
|
||||
}
|
||||
}
|
||||
|
||||
impl SymbolicVariance {
|
||||
pub const ZERO: Self = Self {
|
||||
input_coeff: 0.0,
|
||||
|
||||
Reference in New Issue
Block a user