add range/value_parametrized graphs

This commit is contained in:
Mayeul@Zama
2022-02-16 16:55:38 +01:00
parent 993f5cee60
commit 12a32956a8
7 changed files with 380 additions and 1 deletions

332
src/global_parameters.rs Normal file
View File

@@ -0,0 +1,332 @@
use std::collections::HashSet;
use crate::graph::operator::{Operator, OperatorIndex};
use crate::graph::{parameter_indexed, range_parametrized, unparametrized};
use crate::parameters::{
AtomicPatternParameters, GlweParameters, InputParameter, KsDecompositionParameters,
PbsDecompositionParameters,
};
#[derive(Clone)]
pub(crate) struct ParameterToOperation {
pub glwe: Vec<Vec<OperatorIndex>>,
pub pbs_decomposition: Vec<Vec<OperatorIndex>>,
pub ks_decomposition: Vec<Vec<OperatorIndex>>,
}
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct ParameterCount {
pub glwe: usize,
pub pbs_decomposition: usize,
pub ks_decomposition: usize,
}
#[derive(Clone)]
pub struct ParameterRanges {
pub glwe: Vec<GlweParameters<Range, Range>>,
pub pbs_decomposition: Vec<PbsDecompositionParameters<Range, Range>>, // 0 => lpetit , 1 => l plus grand
pub ks_decomposition: Vec<KsDecompositionParameters<Range, Range>>,
}
pub struct ParameterValues {
pub glwe: Vec<GlweParameters<u16, u16>>,
pub pbs_decomposition: Vec<PbsDecompositionParameters<u16, u16>>,
pub ks_decomposition: Vec<KsDecompositionParameters<u16, u16>>,
}
#[derive(Clone, Copy)]
pub struct ParameterDomains {
// move next comment to pareto ranges definition
// TODO: verify if pareto optimal parameters depends on precisions
pub glwe_pbs_constrained: GlweParameters<Range, Range>,
pub free_glwe: GlweParameters<Range, Range>,
pub pbs_decomposition: PbsDecompositionParameters<Range, Range>,
pub ks_decomposition: KsDecompositionParameters<Range, Range>,
}
pub const DEFAUT_DOMAINS: ParameterDomains = ParameterDomains {
glwe_pbs_constrained: GlweParameters {
log2_polynomial_size: Range { start: 8, end: 15 },
glwe_dimension: Range { start: 1, end: 10 },
},
free_glwe: GlweParameters {
log2_polynomial_size: Range { start: 0, end: 1 },
glwe_dimension: Range {
start: 600,
end: 2000,
},
},
pbs_decomposition: PbsDecompositionParameters {
log2_base: Range { start: 1, end: 65 },
level: Range { start: 1, end: 65 },
},
ks_decomposition: KsDecompositionParameters {
log2_base: Range { start: 1, end: 65 },
level: Range { start: 1, end: 65 },
},
};
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Range {
pub start: u16,
pub end: u16,
}
#[must_use]
pub fn minimal_unify(_g: unparametrized::AtomicPatternDag) -> parameter_indexed::AtomicPatternDag {
todo!()
}
fn convert_maximal(
op: Operator<(), ()>,
) -> Operator<InputParameter<usize>, AtomicPatternParameters<usize, usize, usize, usize, usize>> {
let external_glwe_index = 0;
let internal_lwe_index = 1;
let pbs_decomposition_index = 0;
let ks_decomposition_index = 0;
match op {
Operator::Input { out_precision, .. } => Operator::Input {
out_precision,
extra_data: InputParameter {
lwe_dimension: external_glwe_index,
},
},
Operator::AtomicPattern {
in_precision,
out_precision,
multisum_inputs,
..
} => Operator::AtomicPattern {
in_precision,
out_precision,
multisum_inputs,
extra_data: AtomicPatternParameters {
input_lwe_dimension: external_glwe_index,
ks_decomposition_parameter: ks_decomposition_index,
internal_lwe_dimension: internal_lwe_index,
pbs_decomposition_parameter: pbs_decomposition_index,
output_glwe_params: external_glwe_index,
},
},
}
}
#[must_use]
pub fn maximal_unify(g: unparametrized::AtomicPatternDag) -> parameter_indexed::AtomicPatternDag {
let operators: Vec<_> = g.operators.into_iter().map(convert_maximal).collect();
let parameters = ParameterCount {
glwe: 2,
pbs_decomposition: 1,
ks_decomposition: 1,
};
let mut reverse_map = ParameterToOperation {
glwe: vec![vec![], vec![]],
pbs_decomposition: vec![vec![]],
ks_decomposition: vec![vec![]],
};
for (i, op) in operators.iter().enumerate() {
match op {
Operator::Input { .. } => {
reverse_map.glwe[0].push(OperatorIndex(i));
}
Operator::AtomicPattern { .. } => {
reverse_map.glwe[0].push(OperatorIndex(i));
reverse_map.glwe[1].push(OperatorIndex(i));
reverse_map.pbs_decomposition[0].push(OperatorIndex(i));
reverse_map.ks_decomposition[0].push(OperatorIndex(i));
}
}
}
parameter_indexed::AtomicPatternDag {
operators,
parameters_count: parameters,
reverse_map,
}
}
#[must_use]
pub fn domains_to_ranges(
parameter_indexed::AtomicPatternDag {
operators,
parameters_count,
reverse_map,
}: parameter_indexed::AtomicPatternDag,
domains: ParameterDomains,
) -> range_parametrized::AtomicPatternDag {
let mut constrained_glwe_parameter_indexes = HashSet::new();
for op in &operators {
if let Operator::AtomicPattern { extra_data, .. } = op {
constrained_glwe_parameter_indexes.insert(extra_data.output_glwe_params);
}
}
let mut glwe = vec![];
for i in 0..parameters_count.glwe {
if constrained_glwe_parameter_indexes.contains(&i) {
glwe.push(domains.glwe_pbs_constrained);
} else {
glwe.push(domains.free_glwe);
}
}
let parameter_ranges = ParameterRanges {
glwe,
pbs_decomposition: vec![
domains.pbs_decomposition;
parameters_count.pbs_decomposition as usize
],
ks_decomposition: vec![domains.ks_decomposition; parameters_count.ks_decomposition],
};
range_parametrized::AtomicPatternDag {
operators,
parameter_ranges,
reverse_map,
}
}
// fn fill_ranges(g: parameter_indexed::AtomicPatternDag) -> parameter_ranged::AtomicPatternDag {
// //check unconstrained GlweDim -> set range_poly_size=[1, 2[
// todo!()
// }
#[cfg(test)]
mod tests {
use super::*;
use crate::weight::Weight;
#[test]
fn test_maximal_unify() {
let mut graph = unparametrized::AtomicPatternDag::new();
let input1 = graph.add_input(1);
let input2 = graph.add_input(2);
let atomic_pattern1 =
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
let _atomic_pattern2 = graph.add_atomic_pattern(
4,
4,
vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
);
let graph_params = maximal_unify(graph);
assert_eq!(
graph_params.parameters_count,
ParameterCount {
glwe: 2,
pbs_decomposition: 1,
ks_decomposition: 1,
}
);
assert_eq!(
graph_params.reverse_map.glwe,
vec![
vec![
OperatorIndex(0),
OperatorIndex(1),
OperatorIndex(2),
OperatorIndex(3)
],
vec![OperatorIndex(2), OperatorIndex(3)]
]
);
assert_eq!(
graph_params.reverse_map.pbs_decomposition,
vec![vec![OperatorIndex(2), OperatorIndex(3)]]
);
assert_eq!(
graph_params.reverse_map.ks_decomposition,
vec![vec![OperatorIndex(2), OperatorIndex(3)]]
);
// collectes l'ensemble des parametres
// unify structurellement les parametres identiques
// => counts
// =>
// let parametrized_expr = { global, dag + indexation}
}
#[test]
fn test_simple_lwe() {
let mut graph = unparametrized::AtomicPatternDag::new();
let input1 = graph.add_input(1);
let _input2 = graph.add_input(2);
let graph_params = maximal_unify(graph);
let range_parametrized::AtomicPatternDag {
operators,
parameter_ranges,
reverse_map: _,
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
let input_1_lwe_params = match &operators[input1.0] {
Operator::Input { extra_data, .. } => extra_data.lwe_dimension,
_ => unreachable!(),
};
dbg!(&parameter_ranges.glwe);
assert_eq!(
DEFAUT_DOMAINS.free_glwe,
parameter_ranges.glwe[input_1_lwe_params]
);
}
#[test]
fn test_simple_lwe2() {
let mut graph = unparametrized::AtomicPatternDag::new();
let input1 = graph.add_input(1);
let input2 = graph.add_input(2);
let atomic_pattern1 =
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
let graph_params = maximal_unify(graph);
let range_parametrized::AtomicPatternDag {
operators,
parameter_ranges,
reverse_map: _,
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
let input_1_lwe_params = match &operators[input1.0] {
Operator::Input { extra_data, .. } => extra_data.lwe_dimension,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.glwe_pbs_constrained,
parameter_ranges.glwe[input_1_lwe_params]
);
let atomic_pattern1_out_glwe_params = match &operators[atomic_pattern1.0] {
Operator::AtomicPattern { extra_data, .. } => extra_data.output_glwe_params,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.glwe_pbs_constrained,
parameter_ranges.glwe[atomic_pattern1_out_glwe_params]
);
let atomic_pattern1_internal_glwe_params = match &operators[atomic_pattern1.0] {
Operator::AtomicPattern { extra_data, .. } => extra_data.internal_lwe_dimension,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.free_glwe,
parameter_ranges.glwe[atomic_pattern1_internal_glwe_params]
);
}
}

View File

@@ -1,2 +1,5 @@
pub mod operator;
pub mod parameter_indexed;
pub mod range_parametrized;
pub mod unparametrized;
pub mod value_parametrized;

View File

@@ -0,0 +1,14 @@
use crate::global_parameters::{ParameterCount, ParameterToOperation};
use crate::parameters::{AtomicPatternParameters, InputParameter};
use super::operator::Operator;
type Index = usize;
pub struct AtomicPatternDag {
pub(crate) operators: Vec<
Operator<InputParameter<usize>, AtomicPatternParameters<Index, Index, Index, Index, Index>>,
>,
pub(crate) parameters_count: ParameterCount,
pub(crate) reverse_map: ParameterToOperation,
}

View File

@@ -0,0 +1,14 @@
use crate::global_parameters::{ParameterRanges, ParameterToOperation};
use crate::parameters::{AtomicPatternParameters, InputParameter};
use super::operator::Operator;
type Index = usize;
pub struct AtomicPatternDag {
pub(crate) operators: Vec<
Operator<InputParameter<usize>, AtomicPatternParameters<Index, Index, Index, Index, Index>>,
>,
pub(crate) parameter_ranges: ParameterRanges,
pub(crate) reverse_map: ParameterToOperation,
}

View File

@@ -4,7 +4,7 @@ use crate::weight::Weight;
#[derive(Clone)]
#[must_use]
pub struct AtomicPatternDag {
operators: Vec<Operator<(), ()>>,
pub(crate) operators: Vec<Operator<(), ()>>,
}
impl AtomicPatternDag {

View File

@@ -0,0 +1,14 @@
use crate::global_parameters::{ParameterToOperation, ParameterValues};
use crate::parameters::{AtomicPatternParameters, InputParameter};
use super::operator::Operator;
type Index = usize;
pub struct AtomicPatternDag {
pub(crate) operators: Vec<
Operator<InputParameter<usize>, AtomicPatternParameters<Index, Index, Index, Index, Index>>,
>,
pub(crate) parameter_ranges: ParameterValues,
pub(crate) reverse_map: ParameterToOperation,
}

View File

@@ -3,7 +3,9 @@
#![warn(clippy::style)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::missing_const_for_fn)]
pub mod global_parameters;
pub mod graph;
pub mod parameters;
pub mod weight;