feat: rounded lut for classical pbs

step 1, provide the sub-dag construction
This commit is contained in:
rudy
2022-09-21 18:24:58 +02:00
committed by rudy
parent 819c9e71ca
commit 3db828d3d0
11 changed files with 441 additions and 589 deletions

View File

@@ -15,6 +15,7 @@ puruspe = "0.2.0"
[dev-dependencies]
approx = "0.5"
once_cell = "1.16.0"
pretty_assertions = "1.2.1"
[lib]

View File

@@ -1,5 +1,2 @@
pub mod operator;
pub mod parameter_indexed;
pub mod range_parametrized;
pub mod unparametrized;
pub mod value_parametrized;

View File

@@ -65,22 +65,19 @@ pub type Precision = u8;
pub const MIN_PRECISION: Precision = 1;
#[derive(Clone, PartialEq, Debug)]
pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraData> {
pub enum Operator {
Input {
out_precision: Precision,
out_shape: Shape,
extra_data: InputExtraData,
},
Lut {
input: OperatorIndex,
table: FunctionTable,
out_precision: Precision,
extra_data: LutExtraData,
},
Dot {
inputs: Vec<OperatorIndex>,
weights: Weights,
extra_data: DotExtraData,
},
LevelledOp {
inputs: Vec<OperatorIndex>,
@@ -88,7 +85,12 @@ pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraDat
manp: f64,
out_shape: Shape,
comment: String,
extra_data: LevelledOpExtraData,
},
// Used to reduced or increase precision when the cyphertext is compatible with different precision
// This is done without any checking
UnsafeCast {
input: OperatorIndex,
out_precision: Precision, // precision is changed without modifying the input, can be increase or decrease
},
}

View File

@@ -1,25 +0,0 @@
use crate::global_parameters::{ParameterCount, ParameterToOperation};
use super::operator::Operator;
pub struct InputParameterIndexed {
pub lwe_dimension_index: usize,
}
#[derive(Copy, Clone)]
pub struct LutParametersIndexed {
pub input_lwe_dimension_index: usize,
pub ks_decomposition_parameter_index: usize,
pub internal_lwe_dimension_index: usize,
pub br_decomposition_parameter_index: usize,
pub output_glwe_params_index: usize,
}
pub(crate) type OperatorParameterIndexed =
Operator<InputParameterIndexed, LutParametersIndexed, (), ()>;
pub struct OperationDag {
pub(crate) operators: Vec<OperatorParameterIndexed>,
pub(crate) parameters_count: ParameterCount,
pub(crate) reverse_map: ParameterToOperation,
}

View File

@@ -1,9 +0,0 @@
use crate::dag::parameter_indexed::OperatorParameterIndexed;
use crate::global_parameters::{ParameterRanges, ParameterToOperation};
#[allow(dead_code)]
pub struct OperationDag {
pub(crate) operators: Vec<OperatorParameterIndexed>,
pub(crate) parameter_ranges: ParameterRanges,
pub(crate) reverse_map: ParameterToOperation,
}

View File

@@ -1,24 +1,36 @@
use std::fmt::Write;
use crate::dag::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
dot_kind, DotKind, FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision,
Shape, Weights,
};
pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>;
pub(crate) type UnparameterizedOperator = Operator;
#[derive(Clone, PartialEq, Debug)]
#[must_use]
pub struct OperationDag {
pub(crate) operators: Vec<UnparameterizedOperator>,
// Collect all operators ouput shape
pub(crate) out_shapes: Vec<Shape>,
// Collect all operators ouput precision
pub(crate) out_precisions: Vec<Precision>,
}
impl OperationDag {
pub const fn new() -> Self {
Self { operators: vec![] }
Self {
operators: vec![],
out_shapes: vec![],
out_precisions: vec![],
}
}
/// Append `operator` to the dag, recording its inferred output precision and
/// shape, and return the index of the newly added operator.
fn add_operator(&mut self, operator: UnparameterizedOperator) -> OperatorIndex {
let i = self.operators.len();
// Infer precision/shape before pushing, so indexing in the infer_* helpers
// only ever refers to already-registered input operators.
self.out_precisions
.push(self.infer_out_precision(&operator));
self.out_shapes.push(self.infer_out_shape(&operator));
self.operators.push(operator);
OperatorIndex { i }
}
@@ -32,7 +44,6 @@ impl OperationDag {
self.add_operator(Operator::Input {
out_precision,
out_shape,
extra_data: (),
})
}
@@ -46,7 +57,6 @@ impl OperationDag {
input,
table,
out_precision,
extra_data: (),
})
}
@@ -57,11 +67,7 @@ impl OperationDag {
) -> OperatorIndex {
let inputs = inputs.into();
let weights = weights.into();
self.add_operator(Operator::Dot {
inputs,
weights,
extra_data: (),
})
self.add_operator(Operator::Dot { inputs, weights })
}
pub fn add_levelled_op(
@@ -81,11 +87,25 @@ impl OperationDag {
manp,
out_shape,
comment,
extra_data: (),
};
self.add_operator(op)
}
/// Reinterpret `input` at `out_precision` without inserting any conversion
/// operation and without any checking.
/// No-op (returns `input` unchanged) when the precision already matches.
pub fn add_unsafe_cast(
&mut self,
input: OperatorIndex,
out_precision: Precision,
) -> OperatorIndex {
let input_precision = self.out_precisions[input.i];
if input_precision == out_precision {
// Nothing to cast: reuse the existing operator.
return input;
}
self.add_operator(Operator::UnsafeCast {
input,
out_precision,
})
}
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.operators.len()
@@ -100,6 +120,139 @@ impl OperationDag {
}
acc
}
/// Shift the lowest bit of `input` up into the most-significant (padding)
/// position, yielding a 1-bit no-padding value (tracked as precision 0).
fn add_shift_left_lsb_to_msb_no_padding(&mut self, input: OperatorIndex) -> OperatorIndex {
// Convert any input to simple 1bit msb replacing the padding
// For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding.
let in_precision = self.out_precisions[input.i];
// Multiplying by 2^in_precision moves the lsb to the msb position.
let shift_factor = Weights::number(1 << (in_precision as i64));
let lsb_as_msb = self.add_dot([input], shift_factor);
self.add_unsafe_cast(lsb_as_msb, 0 as Precision)
}
/// Apply `table` to a 1-bit no-padding input (asserted precision 0),
/// producing a ciphertext of `out_precision` bits.
fn add_lut_1bit_no_padding(
&mut self,
input: OperatorIndex,
table: FunctionTable,
out_precision: Precision,
) -> OperatorIndex {
// For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding.
let in_precision = self.out_precisions[input.i];
assert!(in_precision == 0);
// An add after with a clear constant is skipped here as it doesn't change noise handling.
self.add_lut(input, table, out_precision)
}
/// Bring a 1-bit no-padding msb back down to the lsb of an
/// `out_precision`-bit zero-padded ciphertext, via a 1-bit lut.
fn add_shift_right_msb_no_padding_to_lsb(
&mut self,
input: OperatorIndex,
out_precision: Precision,
) -> OperatorIndex {
// Convert simple 1 bit msb to a nbit with zero padding
// The actual table content is irrelevant for parameter optimization.
let to_nbits_padded = FunctionTable::UNKWOWN;
self.add_lut_1bit_no_padding(input, to_nbits_padded, out_precision)
}
/// Extract the lowest bit of `input` as a standalone ciphertext of the
/// same precision as the input (shift-left to msb, then lut back to lsb).
fn add_isolate_lowest_bit(&mut self, input: OperatorIndex) -> OperatorIndex {
// The lowest bit is converted to a ciphertext of same precision as input.
// Introduce a pbs of input precision but this precision is only used on 1 levelled op and converted to lower precision
// Noise is reduced by a pbs.
let out_precision = self.out_precisions[input.i];
let lsb_as_msb = self.add_shift_left_lsb_to_msb_no_padding(input);
self.add_shift_right_msb_no_padding_to_lsb(lsb_as_msb, out_precision)
}
/// Clear the lowest bit of `input` and reduce its precision by one.
pub fn add_truncate_1_bit(&mut self, input: OperatorIndex) -> OperatorIndex {
// Reset a bit.
// ex: 10110 is truncated to 1011, 10111 is truncated to 1011
let in_precision = self.out_precisions[input.i];
let lowest_bit = self.add_isolate_lowest_bit(input);
// Subtract the isolated bit, then reinterpret at one bit less precision.
let bit_cleared = self.add_dot([input, lowest_bit], [1, -1]);
self.add_unsafe_cast(bit_cleared, in_precision - 1)
}
/// Round `input` to `rounded_precision` bits by clearing the lower bits
/// one at a time (one truncate sub-graph per removed bit).
pub fn add_expanded_round(
&mut self,
input: OperatorIndex,
rounded_precision: Precision,
) -> OperatorIndex {
// Round such that the output has precision out_precision.
// We round by adding 2**(removed_precision - 1) to the last remaining bit to clear (this step is a no-op).
// Then all lower bits are cleared.
// Note: this is a simplified graph, some constant additions are missing without consequence on crypto parameter choice.
// Note: reset and rounds could be done by 4, 3, 2 and 1 bits groups for efficiency.
// bit efficiency is better for 4 precision than 3, but the feasibility is lower for high noise
let in_precision = self.out_precisions[input.i];
assert!(rounded_precision <= in_precision);
if in_precision == rounded_precision {
// Nothing to round away.
return input;
}
// Add rounding constant, this is represented as a no-op since it doesn't influence crypto parameters.
let mut rounded = input;
// The rounded is in high precision with garbage lowest bits
let bits_to_truncate = in_precision - rounded_precision;
for _ in 1..=bits_to_truncate as i64 {
rounded = self.add_truncate_1_bit(rounded);
}
rounded
}
/// Round `input` to `rounded_precision` bits, then apply `table` yielding
/// `out_precision` bits — the expanded "rounded lut" for a classical pbs.
pub fn add_expanded_rounded_lut(
&mut self,
input: OperatorIndex,
table: FunctionTable,
rounded_precision: Precision,
out_precision: Precision,
) -> OperatorIndex {
// note: this is a simplified graph, some constant additions are missing without consequence on crypto parameter choice.
let rounded = self.add_expanded_round(input, rounded_precision);
self.add_lut(rounded, table, out_precision)
}
/// Compute the output shape of `op` from the already-recorded shapes of
/// its inputs. Panics (with a printed diagnostic) on an invalid dot.
fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape {
match op {
// These operators carry their own output shape.
Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. } => {
out_shape.clone()
}
// Shape-preserving operators: same shape as their single input.
Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } => {
self.out_shapes[input.i].clone()
}
Operator::Dot {
inputs, weights, ..
} => {
let input_shape = self.out_shapes[inputs[0].i].clone();
let kind = dot_kind(inputs.len() as u64, &input_shape, weights);
match kind {
DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => {
Shape::number()
}
DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
DotKind::Unsupported { .. } => {
// Emit a readable diagnostic before panicking.
let weights_shape = &weights.shape;
println!();
println!();
println!("Error diagnostic on dot operation:");
println!(
"Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>"
);
println!();
panic!("Unsupported or invalid dot operation")
}
}
}
}
}
/// Compute the output precision of `op`: explicit for input/lut/cast
/// operators, inherited from the first input for dot/levelled ops.
fn infer_out_precision(&self, op: &UnparameterizedOperator) -> Precision {
match op {
Operator::Input { out_precision, .. }
| Operator::Lut { out_precision, .. }
| Operator::UnsafeCast { out_precision, .. } => *out_precision,
// Dot and levelled ops preserve their inputs' (uniform) precision.
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
self.out_precisions[inputs[0].i]
}
}
}
}
#[cfg(test)]
@@ -138,12 +291,10 @@ mod tests {
Operator::Input {
out_precision: 1,
out_shape: Shape::number(),
extra_data: ()
},
Operator::Input {
out_precision: 2,
out_shape: Shape::number(),
extra_data: ()
},
Operator::LevelledOp {
inputs: vec![input1, input2],
@@ -151,13 +302,11 @@ mod tests {
manp: 1.0,
out_shape: Shape::number(),
comment: "sum".to_string(),
extra_data: ()
},
Operator::Lut {
input: sum1,
table: FunctionTable::UNKWOWN,
out_precision: 1,
extra_data: ()
},
Operator::LevelledOp {
inputs: vec![input1, lut1],
@@ -165,7 +314,6 @@ mod tests {
manp: 1.0,
out_shape: Shape::vector(2),
comment: "concat".to_string(),
extra_data: ()
},
Operator::Dot {
inputs: vec![concat],
@@ -173,15 +321,119 @@ mod tests {
shape: Shape::vector(2),
values: vec![1, 2]
},
extra_data: ()
},
Operator::Lut {
input: dot,
table: FunctionTable::UNKWOWN,
out_precision: 2,
extra_data: ()
}
]
);
}
// Checks the exact operator sequence produced by add_expanded_rounded_lut
// when rounding a 5-bit input down to 2 bits before the final lut:
// three truncate-1-bit sub-graphs (dot, cast, lut, dot, cast) then the lut.
#[test]
fn test_rounded_lut() {
let mut graph = OperationDag::new();
let out_precision = 5;
let rounded_precision = 2;
let input1 = graph.add_input(out_precision, Shape::number());
let _ = graph.add_expanded_rounded_lut(
input1,
FunctionTable::UNKWOWN,
rounded_precision,
out_precision,
);
let expecteds = [
Operator::Input {
out_precision,
out_shape: Shape::number(),
},
// The rounding addition skipped, it's a no-op wrt crypto parameter
// Clear: cleared = input - bit0
//// Extract bit
Operator::Dot {
inputs: vec![input1],
weights: Weights::number(1 << 5),
},
Operator::UnsafeCast {
input: OperatorIndex { i: 1 },
out_precision: 0,
},
//// 1 Bit to out_precision
Operator::Lut {
input: OperatorIndex { i: 2 },
table: FunctionTable::UNKWOWN,
out_precision: 5,
},
//// Erase bit
Operator::Dot {
inputs: vec![input1, OperatorIndex { i: 3 }],
weights: Weights::vector([1, -1]),
},
Operator::UnsafeCast {
input: OperatorIndex { i: 4 },
out_precision: 4,
},
// Clear: cleared = input - bit0 - bit1
//// Extract bit
Operator::Dot {
inputs: vec![OperatorIndex { i: 5 }],
weights: Weights::number(1 << 4),
},
Operator::UnsafeCast {
input: OperatorIndex { i: 6 },
out_precision: 0,
},
//// 1 Bit to out_precision
Operator::Lut {
input: OperatorIndex { i: 7 },
table: FunctionTable::UNKWOWN,
out_precision: 4,
},
//// Erase bit
Operator::Dot {
inputs: vec![OperatorIndex { i: 5 }, OperatorIndex { i: 8 }],
weights: Weights::vector([1, -1]),
},
Operator::UnsafeCast {
input: OperatorIndex { i: 9 },
out_precision: 3,
},
// Clear: cleared = input - bit0 - bit1 - bit2
//// Extract bit
Operator::Dot {
inputs: vec![OperatorIndex { i: 10 }],
weights: Weights::number(1 << 3),
},
Operator::UnsafeCast {
input: OperatorIndex { i: 11 },
out_precision: 0,
},
//// 1 Bit to out_precision
Operator::Lut {
input: OperatorIndex { i: 12 },
table: FunctionTable::UNKWOWN,
out_precision: 3,
},
//// Erase bit
Operator::Dot {
inputs: vec![OperatorIndex { i: 10 }, OperatorIndex { i: 13 }],
weights: Weights::vector([1, -1]),
},
Operator::UnsafeCast {
input: OperatorIndex { i: 14 },
out_precision: 2,
},
// Lut on rounded precision
Operator::Lut {
input: OperatorIndex { i: 15 },
table: FunctionTable::UNKWOWN,
out_precision: 5,
},
];
// Compare the full operator list, element by element, with its index in
// the failure message for easier diagnosis.
assert_eq!(expecteds.len(), graph.operators.len());
for (i, (expected, actual)) in std::iter::zip(expecteds, graph.operators).enumerate() {
assert_eq!(expected, actual, "{i}-th operation");
}
}
}

View File

@@ -1,9 +0,0 @@
use crate::dag::parameter_indexed::OperatorParameterIndexed;
use crate::global_parameters::{ParameterToOperation, ParameterValues};
#[allow(dead_code)]
pub struct OperationDag {
pub(crate) operators: Vec<OperatorParameterIndexed>,
pub(crate) parameter_ranges: ParameterValues,
pub(crate) reverse_map: ParameterToOperation,
}

View File

@@ -1,43 +1,14 @@
use std::collections::HashSet;
use crate::dag::operator::{Operator, OperatorIndex};
use crate::dag::parameter_indexed::{
InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed,
};
use crate::dag::unparametrized::UnparameterizedOperator;
use crate::dag::{parameter_indexed, range_parametrized, unparametrized};
use crate::parameters::{
BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges, GlweParameters,
KsDecompositionParameterRanges, KsDecompositionParameters,
BrDecompositionParameterRanges, GlweParameterRanges, KsDecompositionParameterRanges,
};
#[derive(Clone)]
pub(crate) struct ParameterToOperation {
pub glwe: Vec<Vec<OperatorIndex>>,
pub br_decomposition: Vec<Vec<OperatorIndex>>,
pub ks_decomposition: Vec<Vec<OperatorIndex>>,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct ParameterCount {
struct ParameterCount {
pub glwe: usize,
pub br_decomposition: usize,
pub ks_decomposition: usize,
}
#[derive(Clone)]
pub struct ParameterRanges {
pub glwe: Vec<GlweParameterRanges>,
pub br_decomposition: Vec<BrDecompositionParameterRanges>, // 0 => lpetit , 1 => l plus grand
pub ks_decomposition: Vec<KsDecompositionParameterRanges>,
}
pub struct ParameterValues {
pub glwe: Vec<GlweParameters>,
pub br_decomposition: Vec<BrDecompositionParameters>,
pub ks_decomposition: Vec<KsDecompositionParameters>,
}
#[derive(Clone, Copy)]
pub struct ParameterDomains {
// move next comment to pareto ranges definition
@@ -91,289 +62,3 @@ impl Range {
self.into_iter().collect()
}
}
#[must_use]
#[allow(clippy::needless_pass_by_value)]
pub fn minimal_unify(_g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
todo!()
}
fn convert_maximal(op: UnparameterizedOperator) -> OperatorParameterIndexed {
let external_glwe_index = 0;
let internal_lwe_index = 1;
let br_decomposition_index = 0;
let ks_decomposition_index = 0;
match op {
Operator::Input {
out_precision,
out_shape,
..
} => Operator::Input {
out_precision,
out_shape,
extra_data: InputParameterIndexed {
lwe_dimension_index: external_glwe_index,
},
},
Operator::Lut {
input,
table,
out_precision,
..
} => Operator::Lut {
input,
table,
out_precision,
extra_data: LutParametersIndexed {
input_lwe_dimension_index: external_glwe_index,
ks_decomposition_parameter_index: ks_decomposition_index,
internal_lwe_dimension_index: internal_lwe_index,
br_decomposition_parameter_index: br_decomposition_index,
output_glwe_params_index: external_glwe_index,
},
},
Operator::Dot {
inputs, weights, ..
} => Operator::Dot {
inputs,
weights,
extra_data: (),
},
Operator::LevelledOp {
inputs,
complexity,
manp,
out_shape,
comment,
..
} => Operator::LevelledOp {
inputs,
complexity,
manp,
comment,
out_shape,
extra_data: (),
},
}
}
#[must_use]
pub fn maximal_unify(g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
let operators: Vec<_> = g.operators.into_iter().map(convert_maximal).collect();
let parameters = ParameterCount {
glwe: 2,
br_decomposition: 1,
ks_decomposition: 1,
};
let mut reverse_map = ParameterToOperation {
glwe: vec![vec![], vec![]],
br_decomposition: vec![vec![]],
ks_decomposition: vec![vec![]],
};
for (i, op) in operators.iter().enumerate() {
let index = OperatorIndex { i };
match op {
Operator::Input { .. } => {
reverse_map.glwe[0].push(index);
}
Operator::Lut { .. } => {
reverse_map.glwe[0].push(index);
reverse_map.glwe[1].push(index);
reverse_map.br_decomposition[0].push(index);
reverse_map.ks_decomposition[0].push(index);
}
Operator::Dot { .. } | Operator::LevelledOp { .. } => {
reverse_map.glwe[0].push(index);
reverse_map.glwe[1].push(index);
}
}
}
parameter_indexed::OperationDag {
operators,
parameters_count: parameters,
reverse_map,
}
}
#[must_use]
pub fn domains_to_ranges(
parameter_indexed::OperationDag {
operators,
parameters_count,
reverse_map,
}: parameter_indexed::OperationDag,
domains: ParameterDomains,
) -> range_parametrized::OperationDag {
let mut constrained_glwe_parameter_indexes = HashSet::new();
for op in &operators {
if let Operator::Lut { extra_data, .. } = op {
let _ = constrained_glwe_parameter_indexes.insert(extra_data.output_glwe_params_index);
}
}
let mut glwe = vec![];
for i in 0..parameters_count.glwe {
if constrained_glwe_parameter_indexes.contains(&i) {
glwe.push(domains.glwe_pbs_constrained);
} else {
glwe.push(domains.free_glwe);
}
}
let parameter_ranges = ParameterRanges {
glwe,
br_decomposition: vec![domains.br_decomposition; parameters_count.br_decomposition],
ks_decomposition: vec![domains.ks_decomposition; parameters_count.ks_decomposition],
};
range_parametrized::OperationDag {
operators,
parameter_ranges,
reverse_map,
}
}
// fn fill_ranges(g: parameter_indexed::AtomicPatternDag) -> parameter_ranged::AtomicPatternDag {
// //check unconstrained GlweDim -> set range_poly_size=[1, 2[
// todo!()
// }
#[cfg(test)]
mod tests {
use super::*;
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
#[test]
fn test_maximal_unify() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let input2 = graph.add_input(2, Shape::number());
let cpx_add = LevelledComplexity::ADDITION;
let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 2);
let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
let dot = graph.add_dot([concat], [1, 2]);
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2);
let graph_params = maximal_unify(graph);
assert_eq!(
graph_params.parameters_count,
ParameterCount {
glwe: 2,
br_decomposition: 1,
ks_decomposition: 1,
}
);
assert_eq!(
graph_params.reverse_map.glwe,
vec![
vec![input1, input2, sum1, lut1, concat, dot, lut2],
vec![sum1, lut1, concat, dot, lut2]
]
);
assert_eq!(
graph_params.reverse_map.br_decomposition,
vec![vec![lut1, lut2]]
);
assert_eq!(
graph_params.reverse_map.ks_decomposition,
vec![vec![lut1, lut2]]
);
// collectes l'ensemble des parametres
// unify structurellement les parametres identiques
// => counts
// =>
// let parametrized_expr = { global, dag + indexation}
}
#[test]
fn test_simple_lwe() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let _input2 = graph.add_input(2, Shape::number());
let graph_params = maximal_unify(graph);
let range_parametrized::OperationDag {
operators,
parameter_ranges,
reverse_map: _,
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
let input_1_lwe_params = match &operators[input1.i] {
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
_ => unreachable!(),
};
dbg!(&parameter_ranges.glwe);
assert_eq!(
DEFAUT_DOMAINS.free_glwe,
parameter_ranges.glwe[input_1_lwe_params]
);
}
#[test]
fn test_simple_lwe2() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let input2 = graph.add_input(2, Shape::number());
let cpx_add = LevelledComplexity::ADDITION;
let concat =
graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN, 2);
let graph_params = maximal_unify(graph);
let range_parametrized::OperationDag {
operators,
parameter_ranges,
reverse_map: _,
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
let input_1_lwe_params = match &operators[input1.i] {
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.glwe_pbs_constrained,
parameter_ranges.glwe[input_1_lwe_params]
);
let lut1_out_glwe_params = match &operators[lut1.i] {
Operator::Lut { extra_data, .. } => extra_data.output_glwe_params_index,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.glwe_pbs_constrained,
parameter_ranges.glwe[lut1_out_glwe_params]
);
let lut1_internal_glwe_params = match &operators[lut1.i] {
Operator::Lut { extra_data, .. } => extra_data.internal_lwe_dimension_index,
_ => unreachable!(),
};
assert_eq!(
DEFAUT_DOMAINS.free_glwe,
parameter_ranges.glwe[lut1_internal_glwe_params]
);
}
}

View File

@@ -54,6 +54,8 @@ fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) {
fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
for op in &dag.operators {
assert_non_empty_inputs(op);
assert_inputs_uniform_precisions(op, &dag.out_precisions);
assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
}
}
@@ -68,10 +70,6 @@ fn assert_valid_variances(dag: &OperationDag) {
}
fn assert_properties_correctness(dag: &OperationDag) {
for op in &dag.operators {
assert_inputs_uniform_precisions(op, &dag.out_precisions);
assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
}
assert_valid_variances(dag);
}
@@ -89,10 +87,6 @@ fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance])
#[derive(Clone, Debug)]
pub struct OperationDag {
pub operators: Vec<Op>,
// Collect all operators ouput shape
pub out_shapes: Vec<Shape>,
// Collect all operators ouput precision
pub out_precisions: Vec<Precision>,
// Collect all operators ouput variances
pub out_variances: Vec<SymbolicVariance>,
pub nb_luts: u64,
@@ -118,66 +112,6 @@ pub struct VariancesAndBound {
pub all_in_lut: Vec<(u64, SymbolicVariance)>,
}
fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape {
match op {
Op::Input { out_shape, .. } | Op::LevelledOp { out_shape, .. } => out_shape.clone(),
Op::Lut { input, .. } => out_shapes[input.i].clone(),
Op::Dot {
inputs, weights, ..
} => {
if inputs.is_empty() {
return Shape::number();
}
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor => Shape::number(),
DK::CompatibleTensor => weights.shape.clone(),
DK::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
DK::Unsupported { .. } => {
let weights_shape = &weights.shape;
println!();
println!();
println!("Error diagnostic on dot operation:");
println!("Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>");
println!();
panic!("Unsupported or invalid dot operation")
}
}
}
}
}
fn out_shapes(dag: &unparametrized::OperationDag) -> Vec<Shape> {
let nb_ops = dag.operators.len();
let mut out_shapes = Vec::<Shape>::with_capacity(nb_ops);
for op in &dag.operators {
let shape = out_shape(op, &mut out_shapes);
out_shapes.push(shape);
}
out_shapes
}
fn out_precision(
op: &unparametrized::UnparameterizedOperator,
out_precisions: &[Precision],
) -> Precision {
match op {
Op::Input { out_precision, .. } | Op::Lut { out_precision, .. } => *out_precision,
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i],
}
}
fn out_precisions(dag: &unparametrized::OperationDag) -> Vec<Precision> {
let nb_ops = dag.operators.len();
let mut out_precisions = Vec::<Precision>::with_capacity(nb_ops);
for op in &dag.operators {
let precision = out_precision(op, &out_precisions);
out_precisions.push(precision);
}
out_precisions
}
fn out_variance(
op: &unparametrized::UnparameterizedOperator,
out_shapes: &[Shape],
@@ -220,17 +154,15 @@ fn out_variance(
DK::Unsupported { .. } => panic!("Unsupported"),
}
}
Op::UnsafeCast { input, .. } => out_variances[input.i],
}
}
fn out_variances(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
) -> Vec<SymbolicVariance> {
fn out_variances(dag: &unparametrized::OperationDag) -> Vec<SymbolicVariance> {
let nb_ops = dag.operators.len();
let mut out_variances = Vec::with_capacity(nb_ops);
for op in &dag.operators {
let vf = out_variance(op, out_shapes, &mut out_variances);
let vf = out_variance(op, &dag.out_shapes, &mut out_variances);
out_variances.push(vf);
}
out_variances
@@ -242,7 +174,7 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
for op in &dag.operators {
match op {
Op::Input { .. } => (),
Op::Lut { input, .. } => {
Op::Lut { input, .. } | Op::UnsafeCast { input, .. } => {
extra_values_to_check[input.i] = false;
}
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
@@ -257,8 +189,6 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
fn extra_final_variances(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
out_precisions: &[Precision],
out_variances: &[SymbolicVariance],
) -> Vec<(Precision, Shape, SymbolicVariance)> {
extra_final_values_to_check(dag)
@@ -266,7 +196,11 @@ fn extra_final_variances(
.enumerate()
.filter_map(|(i, &is_final)| {
if is_final {
Some((out_precisions[i], out_shapes[i].clone(), out_variances[i]))
Some((
dag.out_precisions[i],
dag.out_shapes[i].clone(),
out_variances[i],
))
} else {
None
}
@@ -276,8 +210,6 @@ fn extra_final_variances(
fn in_luts_variance(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
out_precisions: &[Precision],
out_variances: &[SymbolicVariance],
) -> Vec<(Precision, Shape, SymbolicVariance)> {
dag.operators
@@ -286,8 +218,8 @@ fn in_luts_variance(
.filter_map(|(i, op)| {
if let &Op::Lut { input, .. } = op {
Some((
out_precisions[input.i],
out_shapes[i].clone(),
dag.out_precisions[input.i],
dag.out_shapes[i].clone(),
out_variances[input.i],
))
} else {
@@ -315,26 +247,23 @@ fn op_levelled_complexity(
}
}
Op::LevelledOp { complexity, .. } => *complexity,
Op::Input { .. } | Op::Lut { .. } => LevelledComplexity::ZERO,
Op::Input { .. } | Op::Lut { .. } | Op::UnsafeCast { .. } => LevelledComplexity::ZERO,
}
}
fn levelled_complexity(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
) -> LevelledComplexity {
fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity {
let mut levelled_complexity = LevelledComplexity::ZERO;
for op in &dag.operators {
levelled_complexity += op_levelled_complexity(op, out_shapes);
levelled_complexity += op_levelled_complexity(op, &dag.out_shapes);
}
levelled_complexity
}
fn lut_count(dag: &unparametrized::OperationDag, out_shapes: &[Shape]) -> u64 {
pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 {
let mut count = 0;
for (i, op) in dag.operators.iter().enumerate() {
if let Op::Lut { .. } = op {
count += out_shapes[i].flat_size();
count += dag.out_shapes[i].flat_size();
}
}
count
@@ -433,10 +362,8 @@ fn constraint_for_one_precision(
pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
assert_dag_correctness(dag);
let out_shapes = out_shapes(dag);
let out_precisions = out_precisions(dag);
let out_variances = out_variances(dag, &out_shapes);
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
let out_variances = out_variances(dag);
let in_luts_variance = in_luts_variance(dag, &out_variances);
let coeffs = in_luts_variance
.iter()
.map(|(_precision, _shape, symbolic_variance)| {
@@ -447,25 +374,18 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
worst.log2()
}
pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 {
lut_count(dag, &out_shapes(dag))
}
pub fn analyze(
dag: &unparametrized::OperationDag,
noise_config: &NoiseBoundConfig,
) -> OperationDag {
assert_dag_correctness(dag);
let out_shapes = out_shapes(dag);
let out_precisions = out_precisions(dag);
let out_variances = out_variances(dag, &out_shapes);
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
let nb_luts = lut_count(dag, &out_shapes);
let extra_final_variances =
extra_final_variances(dag, &out_shapes, &out_precisions, &out_variances);
let levelled_complexity = levelled_complexity(dag, &out_shapes);
let out_variances = out_variances(dag);
let in_luts_variance = in_luts_variance(dag, &out_variances);
let nb_luts = lut_count_from_dag(dag);
let extra_final_variances = extra_final_variances(dag, &out_variances);
let levelled_complexity = levelled_complexity(dag);
let constraints_by_precisions = constraints_by_precisions(
&out_precisions,
&dag.out_precisions,
&extra_final_variances,
&in_luts_variance,
noise_config,
@@ -475,8 +395,6 @@ pub fn analyze(
.all(|(_, _, sb)| sb.origin() == VarianceOrigin::Input);
let result = OperationDag {
operators: dag.operators.clone(),
out_shapes,
out_precisions,
out_variances,
nb_luts,
levelled_complexity,
@@ -712,9 +630,9 @@ mod tests {
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT);
assert_eq!(analysis.out_shapes[input1.i], Shape::number());
assert_eq!(graph.out_shapes[input1.i], Shape::number());
assert_eq!(analysis.levelled_complexity, LevelledComplexity::ZERO);
assert_eq!(analysis.out_precisions[input1.i], 1);
assert_eq!(graph.out_precisions[input1.i], 1);
assert_f64_eq(complexity_cost, 0.0);
assert!(analysis.nb_luts == 0);
let constraint = analysis.constraint();
@@ -735,9 +653,9 @@ mod tests {
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT);
assert!(analysis.out_shapes[lut1.i] == Shape::number());
assert!(graph.out_shapes[lut1.i] == Shape::number());
assert!(analysis.levelled_complexity == LevelledComplexity::ZERO);
assert_eq!(analysis.out_precisions[lut1.i], 8);
assert_eq!(graph.out_precisions[lut1.i], 8);
assert_f64_eq(one_lut_cost, complexity_cost);
let constraint = analysis.constraint();
assert!(constraint.pareto_output.len() == 1);
@@ -765,9 +683,9 @@ mod tests {
lut_coeff: 0.0,
};
assert!(analysis.out_variances[dot.i] == expected_var);
assert!(analysis.out_shapes[dot.i] == Shape::number());
assert!(graph.out_shapes[dot.i] == Shape::number());
assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 2);
assert_eq!(analysis.out_precisions[dot.i], 1);
assert_eq!(graph.out_precisions[dot.i], 1);
let expected_dot_cost = (2 * lwe_dim) as f64;
assert_f64_eq(expected_dot_cost, complexity_cost);
let constraint = analysis.constraint();
@@ -792,7 +710,7 @@ mod tests {
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
assert!(analysis.out_variances[dot.i].origin() == VO::Input);
assert_eq!(analysis.out_precisions[dot.i], 3);
assert_eq!(graph.out_precisions[dot.i], 3);
let expected_square_norm2 = weights.square_norm2() as f64;
let actual_square_norm2 = analysis.out_variances[dot.i].input_coeff;
// Due to call on log2() to compute manp the result is not exact

View File

@@ -240,10 +240,10 @@ pub fn optimize(
maximum_acceptable_error_probability: config.maximum_acceptable_error_probability,
ciphertext_modulus_log,
};
let dag = analyze::analyze(dag, &noise_config);
let &min_precision = dag.out_precisions.iter().min().unwrap();
let dag = analyze::analyze(dag, &noise_config);
let safe_variance = error::safe_variance_bound_2padbits(
min_precision as u64,
ciphertext_modulus_log,
@@ -357,6 +357,8 @@ pub fn optimize_v0(
mod tests {
use std::time::Instant;
use once_cell::sync::Lazy;
use super::*;
use crate::computing_cost::cpu::CpuComplexity;
use crate::config;
@@ -388,10 +390,12 @@ mod tests {
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
fn optimize(
dag: &unparametrized::OperationDag,
cache: &PersistDecompCache,
) -> OptimizationState {
static SHARED_CACHES: Lazy<PersistDecompCache> = Lazy::new(|| {
let processing_unit = config::ProcessingUnit::Cpu;
decomposition::cache(128, processing_unit, None)
});
fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
let config = Config {
security_level: 128,
maximum_acceptable_error_probability: _4_SIGMA,
@@ -401,7 +405,7 @@ mod tests {
let search_space = SearchSpace::default_cpu();
super::optimize(dag, config, &search_space, cache)
super::optimize(dag, config, &search_space, &SHARED_CACHES)
}
struct Times {
@@ -442,15 +446,13 @@ mod tests {
complexity_model: &CpuComplexity::default(),
};
let cache = decomposition::cache(config.security_level, processing_unit, None);
let _ = optimize_v0(
sum_size,
precision,
config,
weight as f64,
&search_space,
&cache,
&SHARED_CACHES,
);
// ensure cache is filled
@@ -461,7 +463,7 @@ mod tests {
config,
weight as f64,
&search_space,
&cache,
&SHARED_CACHES,
);
times.dag_time += chrono.elapsed().as_nanos();
@@ -472,7 +474,7 @@ mod tests {
config,
weight as f64,
&search_space,
&cache,
&SHARED_CACHES,
);
times.worst_time += chrono.elapsed().as_nanos();
assert_eq!(
@@ -504,8 +506,6 @@ mod tests {
let processing_unit = config::ProcessingUnit::Cpu;
let security_level = 128;
let cache = decomposition::cache(security_level, processing_unit, None);
let mut dag = unparametrized::OperationDag::new();
{
let input1 = dag.add_input(precision, Shape::number());
@@ -542,14 +542,14 @@ mod tests {
complexity_model: &CpuComplexity::default(),
};
let state = optimize(&dag, &cache);
let state = optimize(&dag);
let state_ref = atomic_pattern::optimize_one(
1,
precision as u64,
config,
weight as f64,
&search_space,
&cache,
&SHARED_CACHES,
);
assert_eq!(
state.best_solution.is_some(),
@@ -567,7 +567,7 @@ mod tests {
assert!(sol.global_p_error <= 1.0);
}
fn no_lut_vs_lut(precision: Precision, cache: &PersistDecompCache) {
fn no_lut_vs_lut(precision: Precision) {
let mut dag_lut = unparametrized::OperationDag::new();
let input1 = dag_lut.add_input(precision, Shape::number());
let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision);
@@ -575,8 +575,8 @@ mod tests {
let mut dag_no_lut = unparametrized::OperationDag::new();
let _input2 = dag_no_lut.add_input(precision, Shape::number());
let state_no_lut = optimize(&dag_no_lut, cache);
let state_lut = optimize(&dag_lut, cache);
let state_no_lut = optimize(&dag_no_lut);
let state_lut = optimize(&dag_lut);
assert_eq!(
state_no_lut.best_solution.is_some(),
state_lut.best_solution.is_some()
@@ -592,17 +592,14 @@ mod tests {
}
#[test]
fn test_lut_vs_no_lut() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
for precision in 1..=8 {
no_lut_vs_lut(precision, &cache);
no_lut_vs_lut(precision);
}
}
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
precision: Precision,
weight: i64,
cache: &PersistDecompCache,
) {
let weight = &Weights::number(weight);
@@ -622,8 +619,8 @@ mod tests {
let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision);
}
let state_1 = optimize(&dag_1, cache);
let state_2 = optimize(&dag_2, cache);
let state_1 = optimize(&dag_1);
let state_2 = optimize(&dag_2);
if state_1.best_solution.is_none() {
assert!(state_2.best_solution.is_none());
@@ -636,19 +633,15 @@ mod tests {
#[test]
fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
for log_weight in 1..=16 {
let weight = 1 << log_weight;
for precision in 5..=9 {
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
precision, weight, &cache,
);
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight);
}
}
}
fn lut_1_layer_has_better_complexity(precision: Precision, cache: &PersistDecompCache) {
fn lut_1_layer_has_better_complexity(precision: Precision) {
let dag_1_layer = {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision, Shape::number());
@@ -664,19 +657,17 @@ mod tests {
dag
};
let sol_1_layer = optimize(&dag_1_layer, cache).best_solution.unwrap();
let sol_2_layer = optimize(&dag_2_layer, cache).best_solution.unwrap();
let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap();
let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap();
assert!(sol_1_layer.complexity < sol_2_layer.complexity);
}
#[test]
fn test_lut_1_layer_is_better() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
// for some reason on 4, 5, 6, the complexity is already minimal
// this could be due to pre-defined pareto set
for precision in [1, 2, 3, 7, 8] {
lut_1_layer_has_better_complexity(precision, &cache);
lut_1_layer_has_better_complexity(precision);
}
}
@@ -688,10 +679,7 @@ mod tests {
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
}
fn assert_multi_precision_dominate_single(
weight: i64,
cache: &PersistDecompCache,
) -> Option<bool> {
fn assert_multi_precision_dominate_single(weight: i64) -> Option<bool> {
let low_precision = 4u8;
let high_precision = 5u8;
let mut dag_low = unparametrized::OperationDag::new();
@@ -704,12 +692,12 @@ mod tests {
circuit(&mut dag_multi, low_precision, weight);
circuit(&mut dag_multi, high_precision, 1);
}
let state_multi = optimize(&dag_multi, cache);
let state_multi = optimize(&dag_multi);
let mut sol_multi = state_multi.best_solution?;
let state_low = optimize(&dag_low, cache);
let state_high = optimize(&dag_high, cache);
let state_low = optimize(&dag_low);
let state_high = optimize(&dag_high);
let sol_low = state_low.best_solution.unwrap();
let sol_high = state_high.best_solution.unwrap();
@@ -728,12 +716,10 @@ mod tests {
#[test]
fn test_multi_precision_dominate_single() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
let mut prev = Some(true); // true -> ... -> true -> false -> ... -> false
for log2_weight in 0..29 {
let weight = 1 << log2_weight;
let current = assert_multi_precision_dominate_single(weight, &cache);
let current = assert_multi_precision_dominate_single(weight);
#[allow(clippy::match_like_matches_macro)] // less readable
let authorized = match (prev, current) {
(Some(false), Some(true)) => false,
@@ -763,29 +749,22 @@ mod tests {
#[test]
fn test_global_p_error_input() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
for precision in [4_u8, 8] {
for weight in [1, 3, 27, 243, 729] {
for dim in [1, 2, 16, 32] {
let _ = check_global_p_error_input(dim, weight, precision, &cache);
let _ = check_global_p_error_input(dim, weight, precision);
}
}
}
}
fn check_global_p_error_input(
dim: u64,
weight: i64,
precision: u8,
cache: &PersistDecompCache,
) -> f64 {
fn check_global_p_error_input(dim: u64, weight: i64, precision: u8) -> f64 {
let shape = Shape::vector(dim);
let weights = Weights::number(weight);
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision, shape);
let _dot1 = dag.add_dot([input1], weights); // this is just several multiply
let state = optimize(&dag, cache);
let state = optimize(&dag);
let sol = state.best_solution.unwrap();
let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim);
approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim);
@@ -794,23 +773,16 @@ mod tests {
#[test]
fn test_global_p_error_lut() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
for precision in [4_u8, 8] {
for weight in [1, 3, 27, 243, 729] {
for depth in [2, 16, 32] {
check_global_p_error_lut(depth, weight, precision, &cache);
check_global_p_error_lut(depth, weight, precision);
}
}
}
}
fn check_global_p_error_lut(
depth: u64,
weight: i64,
precision: u8,
cache: &PersistDecompCache,
) {
fn check_global_p_error_lut(depth: u64, weight: i64, precision: u8) {
let shape = Shape::number();
let weights = Weights::number(weight);
let mut dag = unparametrized::OperationDag::new();
@@ -819,7 +791,7 @@ mod tests {
let dot = dag.add_dot([last_val], &weights);
last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision);
}
let state = optimize(&dag, cache);
let state = optimize(&dag);
let sol = state.best_solution.unwrap();
// the first lut on input has reduced impact on error probability
let lower_nb_dominating_lut = depth - 1;
@@ -856,8 +828,6 @@ mod tests {
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
#[test]
fn test_global_p_error_dominating_lut() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
let depth = 128;
let weights_low = 1;
let weights_high = 1;
@@ -870,7 +840,7 @@ mod tests {
weights_low,
weights_high,
);
let sol = optimize(&dag, &cache).best_solution.unwrap();
let sol = optimize(&dag).best_solution.unwrap();
// the 2 first luts and low precision/weight luts have little impact on error probability
let nb_dominating_lut = depth - 1;
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
@@ -885,8 +855,6 @@ mod tests {
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
#[test]
fn test_global_p_error_non_dominating_lut() {
let processing_unit = config::ProcessingUnit::Cpu;
let cache = decomposition::cache(128, processing_unit, None);
let depth = 128;
let weights_low = 1024 * 1024 * 3;
let weights_high = 1;
@@ -899,7 +867,7 @@ mod tests {
weights_low,
weights_high,
);
let sol = optimize(&dag, &cache).best_solution.unwrap();
let sol = optimize(&dag).best_solution.unwrap();
// all internal luts have an almost equal impact on error probability
let nb_dominating_lut = (2 * depth) - 1;
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
@@ -910,4 +878,89 @@ mod tests {
max_relative = 0.05
);
}
/// Builds a small test graph: input -> rounded lut -> dot -> rounded lut.
/// The first lut outputs at `precision`, so the dot produces a high-precision
/// intermediate inside the levelled part; both luts read their table at
/// `rounded_precision`.
fn circuit_with_rounded_lut(
    rounded_precision: Precision,
    precision: Precision,
    weight: i64,
) -> unparametrized::OperationDag {
    let mut dag = unparametrized::OperationDag::new();
    let weights = Weights::number(weight);
    let input = dag.add_input(precision, &Shape::number());
    let first_lut =
        dag.add_expanded_rounded_lut(input, FunctionTable::UNKWOWN, rounded_precision, precision);
    let scaled = dag.add_dot([first_lut], &weights);
    let _final_lut = dag.add_expanded_rounded_lut(
        scaled,
        FunctionTable::UNKWOWN,
        rounded_precision,
        rounded_precision,
    );
    dag
}
/// Compares the optimizer's cost on three configurations of the same circuit:
/// full precision everywhere, rounded luts on a high-precision accumulator,
/// and everything reduced to the rounded precision.
/// Rounding must cost at most a (roughly linear in the precision gap) slowdown
/// versus the reduced circuit, while still being at least a linear speedup
/// versus the non-rounded circuit.
fn check_global_p_error_rounded_lut(
    precision: Precision,
    rounded_precision: Precision,
    weight: i64,
) {
    let dag_full = circuit_with_rounded_lut(precision, precision, weight);
    let dag_rounded = circuit_with_rounded_lut(rounded_precision, precision, weight);
    let dag_reduced = circuit_with_rounded_lut(rounded_precision, rounded_precision, weight);
    let reduced_complexity = optimize(&dag_reduced).best_solution.unwrap().complexity;
    let rounded_complexity = optimize(&dag_rounded).best_solution.unwrap().complexity;
    // The non-rounded circuit may be unsolvable at high precision.
    let full_complexity = optimize(&dag_full)
        .best_solution
        .map_or(f64::INFINITY, |s| s.complexity);
    let delta = precision - rounded_precision;
    // Thresholds are empirically tuned — TODO confirm if optimizer costs change.
    let tight_margin = weight == 0 && delta <= 4 || weight == 16 && delta <= 3;
    if tight_margin {
        // Linear slowdown with almost no margin.
        assert!(
            rounded_complexity <= reduced_complexity * (1.0 + 1.01 * delta as f64)
        );
    } else if delta <= 4 {
        // Linear slowdown with margin.
        assert!(
            rounded_complexity <= reduced_complexity * (1.0 + 1.5 * delta as f64)
        );
    } else if precision != rounded_precision {
        // Beyond 4 bits of rounding we only require *some* slowdown.
        assert!(reduced_complexity < rounded_complexity);
    }
    // Linear speedup versus the non-rounded circuit (trivially true when delta == 0).
    assert!(rounded_complexity * delta as f64 <= full_complexity);
}
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
#[test]
fn test_global_p_error_rounded_lut() {
    // Fixed 8-bit circuit, sweeping the rounded output precision up to no rounding.
    let precision = 8 as Precision;
    for rounded_precision in 4..=8 {
        check_global_p_error_rounded_lut(precision, rounded_precision, 1);
    }
}
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
#[test]
fn test_global_p_error_increased_accumulator() {
    // Fixed 8-bit rounded output, sweeping accumulator precision and power-of-two weights.
    let rounded_precision = 8 as Precision;
    for precision in 9..=20 {
        for log2_weight in 0..8 {
            let weight = 1 << log2_weight;
            println!("{precision} {weight}");
            check_global_p_error_rounded_lut(precision, rounded_precision, weight);
        }
    }
}
}

View File

@@ -1,5 +1,5 @@
use crate::dag::operator::{Operator, Precision};
use crate::dag::unparametrized::{OperationDag, UnparameterizedOperator};
use crate::dag::operator::Precision;
use crate::dag::unparametrized::OperationDag;
use crate::noise_estimator::p_error::repeat_p_error;
use crate::optimization::atomic_pattern::Solution as WpSolution;
use crate::optimization::config::{Config, SearchSpace};
@@ -18,21 +18,8 @@ pub enum Solution {
WopSolution(WopSolution),
}
fn precision_op(op: &UnparameterizedOperator) -> Option<Precision> {
match op {
Operator::Input { out_precision, .. } | Operator::Lut { out_precision, .. } => {
Some(*out_precision)
}
Operator::Dot { .. } | Operator::LevelledOp { .. } => None,
}
}
fn max_precision(dag: &OperationDag) -> Precision {
dag.operators
.iter()
.filter_map(precision_op)
.max()
.unwrap_or(0)
dag.out_precisions.iter().copied().max().unwrap_or(0)
}
fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {