feat: new dag with Lut/Dot/LevelledOp

This commit is contained in:
rudy-6-4
2022-05-10 16:46:51 +02:00
committed by GitHub
parent 9e5467294f
commit 0e6a9b01a0
11 changed files with 382 additions and 134 deletions

View File

@@ -6,6 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
delegate = "0.6.2"
derive_more = "0.99.17"
concrete-commons = { git = "ssh://git@github.com/zama-ai/concrete_internal.git", branch = "fix/optimizer_compat" }
concrete-npe = { git = "ssh://git@github.com/zama-ai/concrete_internal.git", branch = "fix/optimizer_compat" }
statrs = "0.15.0"

View File

@@ -1,7 +1,10 @@
use std::collections::HashSet;
use crate::graph::operator::{Operator, OperatorIndex};
use crate::graph::parameter_indexed::{AtomicPatternParametersIndexed, InputParameterIndexed};
use crate::graph::parameter_indexed::{
InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed,
};
use crate::graph::unparametrized::UnparameterizedOperator;
use crate::graph::{parameter_indexed, range_parametrized, unparametrized};
use crate::parameters::{
BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges, GlweParameters,
@@ -92,46 +95,65 @@ impl Range {
#[must_use]
#[allow(clippy::needless_pass_by_value)]
pub fn minimal_unify(_g: unparametrized::AtomicPatternDag) -> parameter_indexed::AtomicPatternDag {
pub fn minimal_unify(_g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
todo!()
}
fn convert_maximal(
op: Operator<(), ()>,
) -> Operator<InputParameterIndexed, AtomicPatternParametersIndexed> {
fn convert_maximal(op: UnparameterizedOperator) -> OperatorParameterIndexed {
let external_glwe_index = 0;
let internal_lwe_index = 1;
let br_decomposition_index = 0;
let ks_decomposition_index = 0;
match op {
Operator::Input { out_precision, .. } => Operator::Input {
Operator::Input {
out_precision,
out_shape,
..
} => Operator::Input {
out_precision,
out_shape,
extra_data: InputParameterIndexed {
lwe_dimension_index: external_glwe_index,
},
},
Operator::AtomicPattern {
in_precision,
out_precision,
multisum_inputs,
..
} => Operator::AtomicPattern {
in_precision,
out_precision,
multisum_inputs,
extra_data: AtomicPatternParametersIndexed {
input_lwe_dimensionlwe_dimension_index: external_glwe_index,
Operator::Lut { input, table, .. } => Operator::Lut {
input,
table,
extra_data: LutParametersIndexed {
input_lwe_dimension_index: external_glwe_index,
ks_decomposition_parameter_index: ks_decomposition_index,
internal_lwe_dimension_index: internal_lwe_index,
br_decomposition_parameter_index: br_decomposition_index,
output_glwe_params_index: external_glwe_index,
},
},
Operator::Dot {
inputs, weights, ..
} => Operator::Dot {
inputs,
weights,
extra_data: (),
},
Operator::LevelledOp {
inputs,
complexity,
manp,
out_shape,
comment,
..
} => Operator::LevelledOp {
inputs,
complexity,
manp,
comment,
out_shape,
extra_data: (),
},
}
}
#[must_use]
pub fn maximal_unify(g: unparametrized::AtomicPatternDag) -> parameter_indexed::AtomicPatternDag {
pub fn maximal_unify(g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
let operators: Vec<_> = g.operators.into_iter().map(convert_maximal).collect();
let parameters = ParameterCount {
@@ -147,20 +169,25 @@ pub fn maximal_unify(g: unparametrized::AtomicPatternDag) -> parameter_indexed::
};
for (i, op) in operators.iter().enumerate() {
let index = OperatorIndex { i };
match op {
Operator::Input { .. } => {
reverse_map.glwe[0].push(OperatorIndex(i));
reverse_map.glwe[0].push(index);
}
Operator::AtomicPattern { .. } => {
reverse_map.glwe[0].push(OperatorIndex(i));
reverse_map.glwe[1].push(OperatorIndex(i));
reverse_map.br_decomposition[0].push(OperatorIndex(i));
reverse_map.ks_decomposition[0].push(OperatorIndex(i));
Operator::Lut { .. } => {
reverse_map.glwe[0].push(index);
reverse_map.glwe[1].push(index);
reverse_map.br_decomposition[0].push(index);
reverse_map.ks_decomposition[0].push(index);
}
Operator::Dot { .. } | Operator::LevelledOp { .. } => {
reverse_map.glwe[0].push(index);
reverse_map.glwe[1].push(index);
}
}
}
parameter_indexed::AtomicPatternDag {
parameter_indexed::OperationDag {
operators,
parameters_count: parameters,
reverse_map,
@@ -169,16 +196,16 @@ pub fn maximal_unify(g: unparametrized::AtomicPatternDag) -> parameter_indexed::
#[must_use]
pub fn domains_to_ranges(
parameter_indexed::AtomicPatternDag {
parameter_indexed::OperationDag {
operators,
parameters_count,
reverse_map,
}: parameter_indexed::AtomicPatternDag,
}: parameter_indexed::OperationDag,
domains: ParameterDomains,
) -> range_parametrized::AtomicPatternDag {
) -> range_parametrized::OperationDag {
let mut constrained_glwe_parameter_indexes = HashSet::new();
for op in &operators {
if let Operator::AtomicPattern { extra_data, .. } = op {
if let Operator::Lut { extra_data, .. } = op {
let _ = constrained_glwe_parameter_indexes.insert(extra_data.output_glwe_params_index);
}
}
@@ -202,7 +229,7 @@ pub fn domains_to_ranges(
ks_decomposition: vec![domains.ks_decomposition; parameters_count.ks_decomposition],
};
range_parametrized::AtomicPatternDag {
range_parametrized::OperationDag {
operators,
parameter_ranges,
reverse_map,
@@ -217,24 +244,27 @@ pub fn domains_to_ranges(
#[cfg(test)]
mod tests {
use super::*;
use crate::weight::Weight;
use crate::graph::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
#[test]
fn test_maximal_unify() {
let mut graph = unparametrized::AtomicPatternDag::new();
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1);
let input1 = graph.add_input(1, Shape::number());
let input2 = graph.add_input(2);
let input2 = graph.add_input(2, Shape::number());
let atomic_pattern1 =
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
let cpx_add = LevelledComplexity::ADDITION;
let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let _atomic_pattern2 = graph.add_atomic_pattern(
4,
4,
vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
);
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
let concat =
graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
let graph_params = maximal_unify(graph);
@@ -250,24 +280,19 @@ mod tests {
assert_eq!(
graph_params.reverse_map.glwe,
vec![
vec![
OperatorIndex(0),
OperatorIndex(1),
OperatorIndex(2),
OperatorIndex(3)
],
vec![OperatorIndex(2), OperatorIndex(3)]
vec![input1, input2, sum1, lut1, concat, dot, lut2],
vec![sum1, lut1, concat, dot, lut2]
]
);
assert_eq!(
graph_params.reverse_map.br_decomposition,
vec![vec![OperatorIndex(2), OperatorIndex(3)]]
vec![vec![lut1, lut2]]
);
assert_eq!(
graph_params.reverse_map.ks_decomposition,
vec![vec![OperatorIndex(2), OperatorIndex(3)]]
vec![vec![lut1, lut2]]
);
// collectes l'ensemble des parametres
// unify structurellement les parametres identiques
@@ -278,19 +303,19 @@ mod tests {
#[test]
fn test_simple_lwe() {
let mut graph = unparametrized::AtomicPatternDag::new();
let input1 = graph.add_input(1);
let _input2 = graph.add_input(2);
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let _input2 = graph.add_input(2, Shape::number());
let graph_params = maximal_unify(graph);
let range_parametrized::AtomicPatternDag {
let range_parametrized::OperationDag {
operators,
parameter_ranges,
reverse_map: _,
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
let input_1_lwe_params = match &operators[input1.0] {
let input_1_lwe_params = match &operators[input1.i] {
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
_ => unreachable!(),
};
@@ -305,22 +330,25 @@ mod tests {
#[test]
fn test_simple_lwe2() {
let mut graph = unparametrized::AtomicPatternDag::new();
let input1 = graph.add_input(1);
let input2 = graph.add_input(2);
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let input2 = graph.add_input(2, Shape::number());
let atomic_pattern1 =
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
let cpx_add = LevelledComplexity::ADDITION;
let concat =
graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN);
let graph_params = maximal_unify(graph);
let range_parametrized::AtomicPatternDag {
let range_parametrized::OperationDag {
operators,
parameter_ranges,
reverse_map: _,
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
let input_1_lwe_params = match &operators[input1.0] {
let input_1_lwe_params = match &operators[input1.i] {
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
_ => unreachable!(),
};
@@ -329,22 +357,22 @@ mod tests {
parameter_ranges.glwe[input_1_lwe_params]
);
let atomic_pattern1_out_glwe_params = match &operators[atomic_pattern1.0] {
Operator::AtomicPattern { extra_data, .. } => extra_data.output_glwe_params_index,
let lut1_out_glwe_params = match &operators[lut1.i] {
Operator::Lut { extra_data, .. } => extra_data.output_glwe_params_index,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.glwe_pbs_constrained,
parameter_ranges.glwe[atomic_pattern1_out_glwe_params]
parameter_ranges.glwe[lut1_out_glwe_params]
);
let atomic_pattern1_internal_glwe_params = match &operators[atomic_pattern1.0] {
Operator::AtomicPattern { extra_data, .. } => extra_data.internal_lwe_dimension_index,
let lut1_internal_glwe_params = match &operators[lut1.i] {
Operator::Lut { extra_data, .. } => extra_data.internal_lwe_dimension_index,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.free_glwe,
parameter_ranges.glwe[atomic_pattern1_internal_glwe_params]
parameter_ranges.glwe[lut1_internal_glwe_params]
);
}
}

View File

@@ -1,18 +0,0 @@
use crate::weight::Weight;
#[derive(Clone, PartialEq, Eq, Debug)]
pub(crate) enum Operator<InputExtraData, AtomicPatternExtraData> {
Input {
out_precision: u8,
extra_data: InputExtraData,
},
AtomicPattern {
in_precision: u8,
out_precision: u8,
multisum_inputs: Vec<(Weight, OperatorIndex)>,
extra_data: AtomicPatternExtraData,
},
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct OperatorIndex(pub usize);

View File

@@ -0,0 +1,6 @@
#![allow(clippy::module_inception)]
//! Graph building blocks: the `Operator` DAG node type and the tensor
//! helpers (`Shape`, `ClearTensor`) it relies on, re-exported flat.
pub mod operator;
pub mod tensor;
pub use self::operator::*;
pub use self::tensor::*;

View File

@@ -0,0 +1,85 @@
use crate::graph::operator::tensor::{ClearTensor, Shape};
use derive_more::{Add, AddAssign};
pub type Weights = ClearTensor;
/// Lookup-table content applied by a `Lut` operator.
#[derive(Clone, PartialEq, Debug)]
pub struct FunctionTable {
    /// Table entries; empty when the concrete function is not known yet.
    pub values: Vec<u64>,
}
impl FunctionTable {
    /// Placeholder for a table whose content is not known yet.
    pub const UNKNOWN: Self = Self { values: vec![] };
    /// Misspelled alias kept for backward compatibility with existing
    /// callers; prefer [`FunctionTable::UNKNOWN`].
    pub const UNKWOWN: Self = Self::UNKNOWN;
}
/// Linear cost model of a levelled operation:
/// `cost = lwe_dim_cost_factor * lwe_dimension + fixed_cost`.
/// `Add`/`AddAssign` are derived field-wise via `derive_more`.
#[derive(Clone, Copy, PartialEq, Add, AddAssign, Debug)]
pub struct LevelledComplexity {
    /// Cost contribution per unit of LWE dimension.
    pub lwe_dim_cost_factor: f64,
    /// Cost independent of the LWE dimension.
    pub fixed_cost: f64,
}
impl LevelledComplexity {
pub const ZERO: Self = Self {
lwe_dim_cost_factor: 0.0,
fixed_cost: 0.0,
};
pub const ADDITION: Self = Self {
lwe_dim_cost_factor: 1.0,
fixed_cost: 0.0,
};
}
impl LevelledComplexity {
pub fn cost(&self, lwe_dimension: u64) -> f64 {
self.lwe_dim_cost_factor * (lwe_dimension as f64) + self.fixed_cost
}
}
impl std::ops::AddAssign<&Self> for LevelledComplexity {
    /// By-reference variant: the type is `Copy`, so copy both operands and
    /// reuse the field-wise `Add` implementation.
    fn add_assign(&mut self, rhs: &Self) {
        *self = *self + *rhs;
    }
}
impl std::ops::Mul<u64> for LevelledComplexity {
    type Output = Self;

    /// Scales both cost components by an integer repetition count.
    fn mul(self, factor: u64) -> Self {
        // Convert once and reuse for both fields.
        let factor = factor as f64;
        Self {
            lwe_dim_cost_factor: self.lwe_dim_cost_factor * factor,
            fixed_cost: self.fixed_cost * factor,
        }
    }
}
/// A node of the operation DAG, generic over the per-variant "extra data"
/// attached by later parametrization passes (`()` while unparameterized).
#[derive(Clone, PartialEq, Debug)]
pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraData> {
    /// DAG input with a declared precision (in bits) and tensor shape.
    Input {
        out_precision: u8,
        out_shape: Shape,
        extra_data: InputExtraData,
    },
    /// Table lookup applied to a single predecessor node.
    Lut {
        input: OperatorIndex,
        table: FunctionTable,
        //reduced_precision: u64
        extra_data: LutExtraData,
    },
    /// Dot product between predecessor nodes and clear weights.
    Dot {
        inputs: Vec<OperatorIndex>,
        weights: Weights,
        extra_data: DotExtraData,
    },
    /// Arbitrary levelled operation with an explicit linear cost model.
    LevelledOp {
        inputs: Vec<OperatorIndex>,
        complexity: LevelledComplexity,
        // NOTE(review): "manp" is presumably a noise-related factor used by
        // the optimizer — not established by this file; confirm its meaning.
        manp: f64,
        out_shape: Shape,
        /// Free-form label used by tests/debugging (e.g. "sum", "concat").
        comment: String,
        extra_data: LevelledOpExtraData,
    },
}
/// Position of an operator inside the DAG's `operators` vector; used as a
/// stable handle to reference predecessor nodes.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct OperatorIndex {
    pub i: usize,
}

View File

@@ -0,0 +1,86 @@
use delegate::delegate;
/// Shape of a tensor: the size of each dimension.
///
/// A rank-0 shape (no dimensions) represents a single scalar ("number").
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Shape {
    dimensions_size: Vec<u64>,
}
impl Shape {
    /// Number of dimensions.
    pub fn rank(&self) -> usize {
        self.dimensions_size.len()
    }
    /// Total number of elements: the product of all dimension sizes.
    /// The empty product is 1, so a scalar has flat size 1.
    pub fn flat_size(&self) -> u64 {
        // Iterator `product` replaces the former manual accumulation loop.
        self.dimensions_size.iter().product()
    }
    /// Rank-0 shape: a single scalar value.
    pub fn number() -> Self {
        Self {
            dimensions_size: vec![],
        }
    }
    pub fn is_number(&self) -> bool {
        self.rank() == 0
    }
    /// Rank-1 shape of the given length.
    pub fn vector(size: u64) -> Self {
        Self {
            dimensions_size: vec![size],
        }
    }
    pub fn is_vector(&self) -> bool {
        self.rank() == 1
    }
    /// Shape of `other` duplicated `out_dim_size` times, i.e. `other` with a
    /// new leading dimension of size `out_dim_size` prepended.
    pub fn duplicated(out_dim_size: u64, other: &Self) -> Self {
        let mut dimensions_size = Vec::with_capacity(other.rank() + 1);
        // `out_dim_size` is already `u64`; the redundant `as u64` cast was removed.
        dimensions_size.push(out_dim_size);
        dimensions_size.extend_from_slice(&other.dimensions_size);
        Self { dimensions_size }
    }
}
/// A tensor of clear (non-encrypted) `u64` values with an explicit shape.
/// `values` is stored flat; its length should equal `shape.flat_size()`
/// (not enforced here — constructors maintain the invariant).
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct ClearTensor {
    pub shape: Shape,
    pub values: Vec<u64>,
}
// Squares a value taken by reference (signature suits `Iterator::map` on `&u64`).
#[allow(clippy::trivially_copy_pass_by_ref)]
fn square(v: &u64) -> u64 {
    let value = *v;
    value * value
}
impl ClearTensor {
    /// Rank-0 tensor holding a single value.
    pub fn number(value: u64) -> Self {
        Self {
            shape: Shape::number(),
            values: vec![value],
        }
    }
    /// Rank-1 tensor copied from the given slice.
    pub fn vector(values: &[u64]) -> Self {
        Self {
            shape: Shape::vector(values.len() as u64),
            values: values.to_vec(),
        }
    }
    // Shape accessors forwarded explicitly to `self.shape`
    // (hand-written instead of the `delegate!` macro; same public API).
    pub fn is_number(&self) -> bool {
        self.shape.is_number()
    }
    pub fn is_vector(&self) -> bool {
        self.shape.is_vector()
    }
    pub fn flat_size(&self) -> u64 {
        self.shape.flat_size()
    }
    pub fn rank(&self) -> usize {
        self.shape.rank()
    }
    /// Squared L2 norm: the sum of the squares of all values.
    pub fn square_norm2(&self) -> u64 {
        self.values.iter().map(square).sum()
    }
}

View File

@@ -7,8 +7,8 @@ pub struct InputParameterIndexed {
}
#[derive(Copy, Clone)]
pub struct AtomicPatternParametersIndexed {
pub input_lwe_dimensionlwe_dimension_index: usize,
pub struct LutParametersIndexed {
pub input_lwe_dimension_index: usize,
pub ks_decomposition_parameter_index: usize,
pub internal_lwe_dimension_index: usize,
pub br_decomposition_parameter_index: usize,
@@ -16,9 +16,9 @@ pub struct AtomicPatternParametersIndexed {
}
pub(crate) type OperatorParameterIndexed =
Operator<InputParameterIndexed, AtomicPatternParametersIndexed>;
Operator<InputParameterIndexed, LutParametersIndexed, (), ()>;
pub struct AtomicPatternDag {
pub struct OperationDag {
pub(crate) operators: Vec<OperatorParameterIndexed>,
pub(crate) parameters_count: ParameterCount,
pub(crate) reverse_map: ParameterToOperation,

View File

@@ -2,7 +2,7 @@ use crate::global_parameters::{ParameterRanges, ParameterToOperation};
use crate::graph::parameter_indexed::OperatorParameterIndexed;
#[allow(dead_code)]
pub struct AtomicPatternDag {
pub struct OperationDag {
pub(crate) operators: Vec<OperatorParameterIndexed>,
pub(crate) parameter_ranges: ParameterRanges,
pub(crate) reverse_map: ParameterToOperation,

View File

@@ -1,46 +1,71 @@
use super::operator::{Operator, OperatorIndex};
use crate::weight::Weight;
use crate::graph::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Shape, Weights,
};
#[derive(Clone)]
pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>;
#[derive(Clone, PartialEq)]
#[must_use]
pub struct AtomicPatternDag {
pub(crate) operators: Vec<Operator<(), ()>>,
pub struct OperationDag {
pub(crate) operators: Vec<UnparameterizedOperator>,
}
impl AtomicPatternDag {
impl OperationDag {
pub const fn new() -> Self {
Self { operators: vec![] }
}
fn add_operator(&mut self, operator: Operator<(), ()>) -> OperatorIndex {
let operator_index = self.operators.len();
fn add_operator(&mut self, operator: UnparameterizedOperator) -> OperatorIndex {
let i = self.operators.len();
self.operators.push(operator);
OperatorIndex(operator_index)
OperatorIndex { i }
}
pub fn add_input(&mut self, out_precision: u8) -> OperatorIndex {
pub fn add_input(&mut self, out_precision: u8, out_shape: Shape) -> OperatorIndex {
self.add_operator(Operator::Input {
out_precision,
out_shape,
extra_data: (),
})
}
pub fn add_atomic_pattern(
&mut self,
in_precision: u8,
out_precision: u8,
multisum_inputs: Vec<(Weight, OperatorIndex)>,
) -> OperatorIndex {
self.add_operator(Operator::AtomicPattern {
in_precision,
out_precision,
multisum_inputs,
pub fn add_lut(&mut self, input: OperatorIndex, table: FunctionTable) -> OperatorIndex {
self.add_operator(Operator::Lut {
input,
table,
extra_data: (),
})
}
pub fn add_dot(&mut self, inputs: &[OperatorIndex], weights: &Weights) -> OperatorIndex {
self.add_operator(Operator::Dot {
inputs: inputs.to_vec(),
weights: weights.clone(),
extra_data: (),
})
}
pub fn add_levelled_op(
&mut self,
inputs: &[OperatorIndex],
complexity: LevelledComplexity,
manp: f64,
out_shape: Shape,
comment: &str,
) -> OperatorIndex {
let inputs = inputs.to_vec();
let comment = comment.to_string();
let op = Operator::LevelledOp {
inputs,
complexity,
manp,
out_shape,
comment,
extra_data: (),
};
self.add_operator(op)
}
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.operators.len()
@@ -50,47 +75,80 @@ impl AtomicPatternDag {
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::operator::Shape;
#[test]
fn graph_creation() {
let mut graph = AtomicPatternDag::new();
let mut graph = OperationDag::new();
let input1 = graph.add_input(1);
let input1 = graph.add_input(1, Shape::number());
let input2 = graph.add_input(2);
let input2 = graph.add_input(2, Shape::number());
let atomic_pattern1 =
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
let cpx_add = LevelledComplexity::ADDITION;
let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let _atomic_pattern2 = graph.add_atomic_pattern(
4,
4,
vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
);
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
let concat =
graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2];
for (expected_i, op_index) in ops_index.iter().enumerate() {
assert_eq!(expected_i, op_index.i);
}
assert_eq!(
&graph.operators,
&[
graph.operators,
vec![
Operator::Input {
out_precision: 1,
out_shape: Shape::number(),
extra_data: ()
},
Operator::Input {
out_precision: 2,
out_shape: Shape::number(),
extra_data: ()
},
Operator::AtomicPattern {
in_precision: 3,
out_precision: 3,
multisum_inputs: vec![(Weight(1), input1), (Weight(2), input2)],
Operator::LevelledOp {
inputs: vec![input1, input2],
complexity: cpx_add,
manp: 1.0,
out_shape: Shape::number(),
comment: "sum".to_string(),
extra_data: ()
},
Operator::AtomicPattern {
in_precision: 4,
out_precision: 4,
multisum_inputs: vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
Operator::Lut {
input: sum1,
table: FunctionTable::UNKWOWN,
extra_data: ()
},
Operator::LevelledOp {
inputs: vec![input1, lut1],
complexity: cpx_add,
manp: 1.0,
out_shape: Shape::vector(2),
comment: "concat".to_string(),
extra_data: ()
},
Operator::Dot {
inputs: vec![concat],
weights: Weights {
shape: Shape::vector(2),
values: vec![1, 2]
},
extra_data: ()
},
Operator::Lut {
input: dot,
table: FunctionTable::UNKWOWN,
extra_data: ()
}
]
);
}

View File

@@ -2,7 +2,7 @@ use crate::global_parameters::{ParameterToOperation, ParameterValues};
use crate::graph::parameter_indexed::OperatorParameterIndexed;
#[allow(dead_code)]
pub struct AtomicPatternDag {
pub struct OperationDag {
pub(crate) operators: Vec<OperatorParameterIndexed>,
pub(crate) parameter_ranges: ParameterValues,
pub(crate) reverse_map: ParameterToOperation,

View File

@@ -3,6 +3,7 @@
#![warn(clippy::style)]
#![allow(clippy::cast_precision_loss)] // u64 to f64
#![allow(clippy::cast_possible_truncation)] // u64 to usize
#![allow(clippy::inline_always)] // needed by delegate
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::missing_const_for_fn)]
#![allow(clippy::module_name_repetitions)]