mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: rounded lut for classical pbs
step 1, provide the sub-dag construction
This commit is contained in:
@@ -15,6 +15,7 @@ puruspe = "0.2.0"
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
once_cell = "1.16.0"
|
||||
pretty_assertions = "1.2.1"
|
||||
|
||||
[lib]
|
||||
|
||||
@@ -1,5 +1,2 @@
|
||||
pub mod operator;
|
||||
pub mod parameter_indexed;
|
||||
pub mod range_parametrized;
|
||||
pub mod unparametrized;
|
||||
pub mod value_parametrized;
|
||||
|
||||
@@ -65,22 +65,19 @@ pub type Precision = u8;
|
||||
pub const MIN_PRECISION: Precision = 1;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraData> {
|
||||
pub enum Operator {
|
||||
Input {
|
||||
out_precision: Precision,
|
||||
out_shape: Shape,
|
||||
extra_data: InputExtraData,
|
||||
},
|
||||
Lut {
|
||||
input: OperatorIndex,
|
||||
table: FunctionTable,
|
||||
out_precision: Precision,
|
||||
extra_data: LutExtraData,
|
||||
},
|
||||
Dot {
|
||||
inputs: Vec<OperatorIndex>,
|
||||
weights: Weights,
|
||||
extra_data: DotExtraData,
|
||||
},
|
||||
LevelledOp {
|
||||
inputs: Vec<OperatorIndex>,
|
||||
@@ -88,7 +85,12 @@ pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraDat
|
||||
manp: f64,
|
||||
out_shape: Shape,
|
||||
comment: String,
|
||||
extra_data: LevelledOpExtraData,
|
||||
},
|
||||
// Used to reduced or increase precision when the cyphertext is compatible with different precision
|
||||
// This is done without any checking
|
||||
UnsafeCast {
|
||||
input: OperatorIndex,
|
||||
out_precision: Precision, // precision is changed without modifying the input, can be increase or decrease
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
use crate::global_parameters::{ParameterCount, ParameterToOperation};
|
||||
|
||||
use super::operator::Operator;
|
||||
|
||||
pub struct InputParameterIndexed {
|
||||
pub lwe_dimension_index: usize,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct LutParametersIndexed {
|
||||
pub input_lwe_dimension_index: usize,
|
||||
pub ks_decomposition_parameter_index: usize,
|
||||
pub internal_lwe_dimension_index: usize,
|
||||
pub br_decomposition_parameter_index: usize,
|
||||
pub output_glwe_params_index: usize,
|
||||
}
|
||||
|
||||
pub(crate) type OperatorParameterIndexed =
|
||||
Operator<InputParameterIndexed, LutParametersIndexed, (), ()>;
|
||||
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<OperatorParameterIndexed>,
|
||||
pub(crate) parameters_count: ParameterCount,
|
||||
pub(crate) reverse_map: ParameterToOperation,
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
use crate::dag::parameter_indexed::OperatorParameterIndexed;
|
||||
use crate::global_parameters::{ParameterRanges, ParameterToOperation};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<OperatorParameterIndexed>,
|
||||
pub(crate) parameter_ranges: ParameterRanges,
|
||||
pub(crate) reverse_map: ParameterToOperation,
|
||||
}
|
||||
@@ -1,24 +1,36 @@
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::dag::operator::{
|
||||
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
|
||||
dot_kind, DotKind, FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision,
|
||||
Shape, Weights,
|
||||
};
|
||||
|
||||
pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>;
|
||||
pub(crate) type UnparameterizedOperator = Operator;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
#[must_use]
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<UnparameterizedOperator>,
|
||||
// Collect all operators ouput shape
|
||||
pub(crate) out_shapes: Vec<Shape>,
|
||||
// Collect all operators ouput precision
|
||||
pub(crate) out_precisions: Vec<Precision>,
|
||||
}
|
||||
|
||||
impl OperationDag {
|
||||
pub const fn new() -> Self {
|
||||
Self { operators: vec![] }
|
||||
Self {
|
||||
operators: vec![],
|
||||
out_shapes: vec![],
|
||||
out_precisions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn add_operator(&mut self, operator: UnparameterizedOperator) -> OperatorIndex {
|
||||
let i = self.operators.len();
|
||||
self.out_precisions
|
||||
.push(self.infer_out_precision(&operator));
|
||||
self.out_shapes.push(self.infer_out_shape(&operator));
|
||||
self.operators.push(operator);
|
||||
OperatorIndex { i }
|
||||
}
|
||||
@@ -32,7 +44,6 @@ impl OperationDag {
|
||||
self.add_operator(Operator::Input {
|
||||
out_precision,
|
||||
out_shape,
|
||||
extra_data: (),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -46,7 +57,6 @@ impl OperationDag {
|
||||
input,
|
||||
table,
|
||||
out_precision,
|
||||
extra_data: (),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -57,11 +67,7 @@ impl OperationDag {
|
||||
) -> OperatorIndex {
|
||||
let inputs = inputs.into();
|
||||
let weights = weights.into();
|
||||
self.add_operator(Operator::Dot {
|
||||
inputs,
|
||||
weights,
|
||||
extra_data: (),
|
||||
})
|
||||
self.add_operator(Operator::Dot { inputs, weights })
|
||||
}
|
||||
|
||||
pub fn add_levelled_op(
|
||||
@@ -81,11 +87,25 @@ impl OperationDag {
|
||||
manp,
|
||||
out_shape,
|
||||
comment,
|
||||
extra_data: (),
|
||||
};
|
||||
self.add_operator(op)
|
||||
}
|
||||
|
||||
pub fn add_unsafe_cast(
|
||||
&mut self,
|
||||
input: OperatorIndex,
|
||||
out_precision: Precision,
|
||||
) -> OperatorIndex {
|
||||
let input_precision = self.out_precisions[input.i];
|
||||
if input_precision == out_precision {
|
||||
return input;
|
||||
}
|
||||
self.add_operator(Operator::UnsafeCast {
|
||||
input,
|
||||
out_precision,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
@@ -100,6 +120,139 @@ impl OperationDag {
|
||||
}
|
||||
acc
|
||||
}
|
||||
|
||||
fn add_shift_left_lsb_to_msb_no_padding(&mut self, input: OperatorIndex) -> OperatorIndex {
|
||||
// Convert any input to simple 1bit msb replacing the padding
|
||||
// For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding.
|
||||
let in_precision = self.out_precisions[input.i];
|
||||
let shift_factor = Weights::number(1 << (in_precision as i64));
|
||||
let lsb_as_msb = self.add_dot([input], shift_factor);
|
||||
self.add_unsafe_cast(lsb_as_msb, 0 as Precision)
|
||||
}
|
||||
|
||||
fn add_lut_1bit_no_padding(
|
||||
&mut self,
|
||||
input: OperatorIndex,
|
||||
table: FunctionTable,
|
||||
out_precision: Precision,
|
||||
) -> OperatorIndex {
|
||||
// For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding.
|
||||
let in_precision = self.out_precisions[input.i];
|
||||
assert!(in_precision == 0);
|
||||
// An add after with a clear constant is skipped here as it doesn't change noise handling.
|
||||
self.add_lut(input, table, out_precision)
|
||||
}
|
||||
|
||||
fn add_shift_right_msb_no_padding_to_lsb(
|
||||
&mut self,
|
||||
input: OperatorIndex,
|
||||
out_precision: Precision,
|
||||
) -> OperatorIndex {
|
||||
// Convert simple 1 bit msb to a nbit with zero padding
|
||||
let to_nbits_padded = FunctionTable::UNKWOWN;
|
||||
self.add_lut_1bit_no_padding(input, to_nbits_padded, out_precision)
|
||||
}
|
||||
|
||||
fn add_isolate_lowest_bit(&mut self, input: OperatorIndex) -> OperatorIndex {
|
||||
// The lowest bit is converted to a cyphertext of same precision as input.
|
||||
// Introduce a pbs of input precision but this precision is only used on 1 levelled op and converted to lower precision
|
||||
// Noise is reduced by a pbs.
|
||||
let out_precision = self.out_precisions[input.i];
|
||||
let lsb_as_msb = self.add_shift_left_lsb_to_msb_no_padding(input);
|
||||
self.add_shift_right_msb_no_padding_to_lsb(lsb_as_msb, out_precision)
|
||||
}
|
||||
|
||||
pub fn add_truncate_1_bit(&mut self, input: OperatorIndex) -> OperatorIndex {
|
||||
// Reset a bit.
|
||||
// ex: 10110 is truncated to 1011, 10111 is truncated to 1011
|
||||
let in_precision = self.out_precisions[input.i];
|
||||
let lowest_bit = self.add_isolate_lowest_bit(input);
|
||||
let bit_cleared = self.add_dot([input, lowest_bit], [1, -1]);
|
||||
self.add_unsafe_cast(bit_cleared, in_precision - 1)
|
||||
}
|
||||
|
||||
pub fn add_expanded_round(
|
||||
&mut self,
|
||||
input: OperatorIndex,
|
||||
rounded_precision: Precision,
|
||||
) -> OperatorIndex {
|
||||
// Round such that the ouput has precision out_precision.
|
||||
// We round by adding 2**(removed_precision - 1) to the last remaining bit to clear (this step is a no-op).
|
||||
// Than all lower bits are cleared.
|
||||
// Note: this is a simplified graph, some constant additions are missing without consequence on crypto parameter choice.
|
||||
// Note: reset and rounds could be done by 4, 3, 2 and 1 bits groups for efficiency.
|
||||
// bit efficiency is better for 4 precision then 3, but the feasability is lower for high noise
|
||||
let in_precision = self.out_precisions[input.i];
|
||||
assert!(rounded_precision <= in_precision);
|
||||
if in_precision == rounded_precision {
|
||||
return input;
|
||||
}
|
||||
// Add rounding constant, this is a represented as non-op since it doesn't influence crypto parameters.
|
||||
let mut rounded = input;
|
||||
// The rounded is in high precision with garbage lowest bits
|
||||
let bits_to_truncate = in_precision - rounded_precision;
|
||||
for _ in 1..=bits_to_truncate as i64 {
|
||||
rounded = self.add_truncate_1_bit(rounded);
|
||||
}
|
||||
rounded
|
||||
}
|
||||
|
||||
pub fn add_expanded_rounded_lut(
|
||||
&mut self,
|
||||
input: OperatorIndex,
|
||||
table: FunctionTable,
|
||||
rounded_precision: Precision,
|
||||
out_precision: Precision,
|
||||
) -> OperatorIndex {
|
||||
// note: this is a simplified graph, some constant additions are missing without consequence on crypto parameter choice.
|
||||
let rounded = self.add_expanded_round(input, rounded_precision);
|
||||
self.add_lut(rounded, table, out_precision)
|
||||
}
|
||||
|
||||
fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape {
|
||||
match op {
|
||||
Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. } => {
|
||||
out_shape.clone()
|
||||
}
|
||||
Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } => {
|
||||
self.out_shapes[input.i].clone()
|
||||
}
|
||||
Operator::Dot {
|
||||
inputs, weights, ..
|
||||
} => {
|
||||
let input_shape = self.out_shapes[inputs[0].i].clone();
|
||||
let kind = dot_kind(inputs.len() as u64, &input_shape, weights);
|
||||
match kind {
|
||||
DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => {
|
||||
Shape::number()
|
||||
}
|
||||
DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
|
||||
DotKind::Unsupported { .. } => {
|
||||
let weights_shape = &weights.shape;
|
||||
println!();
|
||||
println!();
|
||||
println!("Error diagnostic on dot operation:");
|
||||
println!(
|
||||
"Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>"
|
||||
);
|
||||
println!();
|
||||
panic!("Unsupported or invalid dot operation")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_out_precision(&self, op: &UnparameterizedOperator) -> Precision {
|
||||
match op {
|
||||
Operator::Input { out_precision, .. }
|
||||
| Operator::Lut { out_precision, .. }
|
||||
| Operator::UnsafeCast { out_precision, .. } => *out_precision,
|
||||
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
|
||||
self.out_precisions[inputs[0].i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -138,12 +291,10 @@ mod tests {
|
||||
Operator::Input {
|
||||
out_precision: 1,
|
||||
out_shape: Shape::number(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Input {
|
||||
out_precision: 2,
|
||||
out_shape: Shape::number(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::LevelledOp {
|
||||
inputs: vec![input1, input2],
|
||||
@@ -151,13 +302,11 @@ mod tests {
|
||||
manp: 1.0,
|
||||
out_shape: Shape::number(),
|
||||
comment: "sum".to_string(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Lut {
|
||||
input: sum1,
|
||||
table: FunctionTable::UNKWOWN,
|
||||
out_precision: 1,
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::LevelledOp {
|
||||
inputs: vec![input1, lut1],
|
||||
@@ -165,7 +314,6 @@ mod tests {
|
||||
manp: 1.0,
|
||||
out_shape: Shape::vector(2),
|
||||
comment: "concat".to_string(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Dot {
|
||||
inputs: vec![concat],
|
||||
@@ -173,15 +321,119 @@ mod tests {
|
||||
shape: Shape::vector(2),
|
||||
values: vec![1, 2]
|
||||
},
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Lut {
|
||||
input: dot,
|
||||
table: FunctionTable::UNKWOWN,
|
||||
out_precision: 2,
|
||||
extra_data: ()
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rounded_lut() {
|
||||
let mut graph = OperationDag::new();
|
||||
let out_precision = 5;
|
||||
let rounded_precision = 2;
|
||||
let input1 = graph.add_input(out_precision, Shape::number());
|
||||
let _ = graph.add_expanded_rounded_lut(
|
||||
input1,
|
||||
FunctionTable::UNKWOWN,
|
||||
rounded_precision,
|
||||
out_precision,
|
||||
);
|
||||
let expecteds = [
|
||||
Operator::Input {
|
||||
out_precision,
|
||||
out_shape: Shape::number(),
|
||||
},
|
||||
// The rounding addition skipped, it's a no-op wrt crypto parameter
|
||||
// Clear: cleared = input - bit0
|
||||
//// Extract bit
|
||||
Operator::Dot {
|
||||
inputs: vec![input1],
|
||||
weights: Weights::number(1 << 5),
|
||||
},
|
||||
Operator::UnsafeCast {
|
||||
input: OperatorIndex { i: 1 },
|
||||
out_precision: 0,
|
||||
},
|
||||
//// 1 Bit to out_precision
|
||||
Operator::Lut {
|
||||
input: OperatorIndex { i: 2 },
|
||||
table: FunctionTable::UNKWOWN,
|
||||
out_precision: 5,
|
||||
},
|
||||
//// Erase bit
|
||||
Operator::Dot {
|
||||
inputs: vec![input1, OperatorIndex { i: 3 }],
|
||||
weights: Weights::vector([1, -1]),
|
||||
},
|
||||
Operator::UnsafeCast {
|
||||
input: OperatorIndex { i: 4 },
|
||||
out_precision: 4,
|
||||
},
|
||||
// Clear: cleared = input - bit0 - bit1
|
||||
//// Extract bit
|
||||
Operator::Dot {
|
||||
inputs: vec![OperatorIndex { i: 5 }],
|
||||
weights: Weights::number(1 << 4),
|
||||
},
|
||||
Operator::UnsafeCast {
|
||||
input: OperatorIndex { i: 6 },
|
||||
out_precision: 0,
|
||||
},
|
||||
//// 1 Bit to out_precision
|
||||
Operator::Lut {
|
||||
input: OperatorIndex { i: 7 },
|
||||
table: FunctionTable::UNKWOWN,
|
||||
out_precision: 4,
|
||||
},
|
||||
//// Erase bit
|
||||
Operator::Dot {
|
||||
inputs: vec![OperatorIndex { i: 5 }, OperatorIndex { i: 8 }],
|
||||
weights: Weights::vector([1, -1]),
|
||||
},
|
||||
Operator::UnsafeCast {
|
||||
input: OperatorIndex { i: 9 },
|
||||
out_precision: 3,
|
||||
},
|
||||
// Clear: cleared = input - bit0 - bit1 - bit2
|
||||
//// Extract bit
|
||||
Operator::Dot {
|
||||
inputs: vec![OperatorIndex { i: 10 }],
|
||||
weights: Weights::number(1 << 3),
|
||||
},
|
||||
Operator::UnsafeCast {
|
||||
input: OperatorIndex { i: 11 },
|
||||
out_precision: 0,
|
||||
},
|
||||
//// 1 Bit to out_precision
|
||||
Operator::Lut {
|
||||
input: OperatorIndex { i: 12 },
|
||||
table: FunctionTable::UNKWOWN,
|
||||
out_precision: 3,
|
||||
},
|
||||
//// Erase bit
|
||||
Operator::Dot {
|
||||
inputs: vec![OperatorIndex { i: 10 }, OperatorIndex { i: 13 }],
|
||||
weights: Weights::vector([1, -1]),
|
||||
},
|
||||
Operator::UnsafeCast {
|
||||
input: OperatorIndex { i: 14 },
|
||||
out_precision: 2,
|
||||
},
|
||||
// Lut on rounded precision
|
||||
Operator::Lut {
|
||||
input: OperatorIndex { i: 15 },
|
||||
table: FunctionTable::UNKWOWN,
|
||||
out_precision: 5,
|
||||
},
|
||||
];
|
||||
assert_eq!(expecteds.len(), graph.operators.len());
|
||||
for (i, (expected, actual)) in std::iter::zip(expecteds, graph.operators).enumerate() {
|
||||
assert_eq!(expected, actual, "{i}-th operation");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
use crate::dag::parameter_indexed::OperatorParameterIndexed;
|
||||
use crate::global_parameters::{ParameterToOperation, ParameterValues};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<OperatorParameterIndexed>,
|
||||
pub(crate) parameter_ranges: ParameterValues,
|
||||
pub(crate) reverse_map: ParameterToOperation,
|
||||
}
|
||||
@@ -1,43 +1,14 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::dag::operator::{Operator, OperatorIndex};
|
||||
use crate::dag::parameter_indexed::{
|
||||
InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed,
|
||||
};
|
||||
use crate::dag::unparametrized::UnparameterizedOperator;
|
||||
use crate::dag::{parameter_indexed, range_parametrized, unparametrized};
|
||||
use crate::parameters::{
|
||||
BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges, GlweParameters,
|
||||
KsDecompositionParameterRanges, KsDecompositionParameters,
|
||||
BrDecompositionParameterRanges, GlweParameterRanges, KsDecompositionParameterRanges,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ParameterToOperation {
|
||||
pub glwe: Vec<Vec<OperatorIndex>>,
|
||||
pub br_decomposition: Vec<Vec<OperatorIndex>>,
|
||||
pub ks_decomposition: Vec<Vec<OperatorIndex>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub struct ParameterCount {
|
||||
struct ParameterCount {
|
||||
pub glwe: usize,
|
||||
pub br_decomposition: usize,
|
||||
pub ks_decomposition: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ParameterRanges {
|
||||
pub glwe: Vec<GlweParameterRanges>,
|
||||
pub br_decomposition: Vec<BrDecompositionParameterRanges>, // 0 => lpetit , 1 => l plus grand
|
||||
pub ks_decomposition: Vec<KsDecompositionParameterRanges>,
|
||||
}
|
||||
|
||||
pub struct ParameterValues {
|
||||
pub glwe: Vec<GlweParameters>,
|
||||
pub br_decomposition: Vec<BrDecompositionParameters>,
|
||||
pub ks_decomposition: Vec<KsDecompositionParameters>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct ParameterDomains {
|
||||
// move next comment to pareto ranges definition
|
||||
@@ -91,289 +62,3 @@ impl Range {
|
||||
self.into_iter().collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
pub fn minimal_unify(_g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn convert_maximal(op: UnparameterizedOperator) -> OperatorParameterIndexed {
|
||||
let external_glwe_index = 0;
|
||||
let internal_lwe_index = 1;
|
||||
let br_decomposition_index = 0;
|
||||
let ks_decomposition_index = 0;
|
||||
match op {
|
||||
Operator::Input {
|
||||
out_precision,
|
||||
out_shape,
|
||||
..
|
||||
} => Operator::Input {
|
||||
out_precision,
|
||||
out_shape,
|
||||
extra_data: InputParameterIndexed {
|
||||
lwe_dimension_index: external_glwe_index,
|
||||
},
|
||||
},
|
||||
Operator::Lut {
|
||||
input,
|
||||
table,
|
||||
out_precision,
|
||||
..
|
||||
} => Operator::Lut {
|
||||
input,
|
||||
table,
|
||||
out_precision,
|
||||
extra_data: LutParametersIndexed {
|
||||
input_lwe_dimension_index: external_glwe_index,
|
||||
ks_decomposition_parameter_index: ks_decomposition_index,
|
||||
internal_lwe_dimension_index: internal_lwe_index,
|
||||
br_decomposition_parameter_index: br_decomposition_index,
|
||||
output_glwe_params_index: external_glwe_index,
|
||||
},
|
||||
},
|
||||
Operator::Dot {
|
||||
inputs, weights, ..
|
||||
} => Operator::Dot {
|
||||
inputs,
|
||||
weights,
|
||||
extra_data: (),
|
||||
},
|
||||
Operator::LevelledOp {
|
||||
inputs,
|
||||
complexity,
|
||||
manp,
|
||||
out_shape,
|
||||
comment,
|
||||
..
|
||||
} => Operator::LevelledOp {
|
||||
inputs,
|
||||
complexity,
|
||||
manp,
|
||||
comment,
|
||||
out_shape,
|
||||
extra_data: (),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn maximal_unify(g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
|
||||
let operators: Vec<_> = g.operators.into_iter().map(convert_maximal).collect();
|
||||
|
||||
let parameters = ParameterCount {
|
||||
glwe: 2,
|
||||
br_decomposition: 1,
|
||||
ks_decomposition: 1,
|
||||
};
|
||||
|
||||
let mut reverse_map = ParameterToOperation {
|
||||
glwe: vec![vec![], vec![]],
|
||||
br_decomposition: vec![vec![]],
|
||||
ks_decomposition: vec![vec![]],
|
||||
};
|
||||
|
||||
for (i, op) in operators.iter().enumerate() {
|
||||
let index = OperatorIndex { i };
|
||||
match op {
|
||||
Operator::Input { .. } => {
|
||||
reverse_map.glwe[0].push(index);
|
||||
}
|
||||
Operator::Lut { .. } => {
|
||||
reverse_map.glwe[0].push(index);
|
||||
reverse_map.glwe[1].push(index);
|
||||
reverse_map.br_decomposition[0].push(index);
|
||||
reverse_map.ks_decomposition[0].push(index);
|
||||
}
|
||||
Operator::Dot { .. } | Operator::LevelledOp { .. } => {
|
||||
reverse_map.glwe[0].push(index);
|
||||
reverse_map.glwe[1].push(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parameter_indexed::OperationDag {
|
||||
operators,
|
||||
parameters_count: parameters,
|
||||
reverse_map,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn domains_to_ranges(
|
||||
parameter_indexed::OperationDag {
|
||||
operators,
|
||||
parameters_count,
|
||||
reverse_map,
|
||||
}: parameter_indexed::OperationDag,
|
||||
domains: ParameterDomains,
|
||||
) -> range_parametrized::OperationDag {
|
||||
let mut constrained_glwe_parameter_indexes = HashSet::new();
|
||||
for op in &operators {
|
||||
if let Operator::Lut { extra_data, .. } = op {
|
||||
let _ = constrained_glwe_parameter_indexes.insert(extra_data.output_glwe_params_index);
|
||||
}
|
||||
}
|
||||
|
||||
let mut glwe = vec![];
|
||||
|
||||
for i in 0..parameters_count.glwe {
|
||||
if constrained_glwe_parameter_indexes.contains(&i) {
|
||||
glwe.push(domains.glwe_pbs_constrained);
|
||||
} else {
|
||||
glwe.push(domains.free_glwe);
|
||||
}
|
||||
}
|
||||
|
||||
let parameter_ranges = ParameterRanges {
|
||||
glwe,
|
||||
br_decomposition: vec![domains.br_decomposition; parameters_count.br_decomposition],
|
||||
ks_decomposition: vec![domains.ks_decomposition; parameters_count.ks_decomposition],
|
||||
};
|
||||
|
||||
range_parametrized::OperationDag {
|
||||
operators,
|
||||
parameter_ranges,
|
||||
reverse_map,
|
||||
}
|
||||
}
|
||||
|
||||
// fn fill_ranges(g: parameter_indexed::AtomicPatternDag) -> parameter_ranged::AtomicPatternDag {
|
||||
// //check unconstrained GlweDim -> set range_poly_size=[1, 2[
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
|
||||
|
||||
#[test]
|
||||
fn test_maximal_unify() {
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
|
||||
let input2 = graph.add_input(2, Shape::number());
|
||||
|
||||
let cpx_add = LevelledComplexity::ADDITION;
|
||||
let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
|
||||
|
||||
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 2);
|
||||
|
||||
let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
|
||||
|
||||
let dot = graph.add_dot([concat], [1, 2]);
|
||||
|
||||
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2);
|
||||
|
||||
let graph_params = maximal_unify(graph);
|
||||
|
||||
assert_eq!(
|
||||
graph_params.parameters_count,
|
||||
ParameterCount {
|
||||
glwe: 2,
|
||||
br_decomposition: 1,
|
||||
ks_decomposition: 1,
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
graph_params.reverse_map.glwe,
|
||||
vec![
|
||||
vec![input1, input2, sum1, lut1, concat, dot, lut2],
|
||||
vec![sum1, lut1, concat, dot, lut2]
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
graph_params.reverse_map.br_decomposition,
|
||||
vec![vec![lut1, lut2]]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
graph_params.reverse_map.ks_decomposition,
|
||||
vec![vec![lut1, lut2]]
|
||||
);
|
||||
// collectes l'ensemble des parametres
|
||||
// unify structurellement les parametres identiques
|
||||
// => counts
|
||||
// =>
|
||||
// let parametrized_expr = { global, dag + indexation}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_lwe() {
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
let _input2 = graph.add_input(2, Shape::number());
|
||||
|
||||
let graph_params = maximal_unify(graph);
|
||||
|
||||
let range_parametrized::OperationDag {
|
||||
operators,
|
||||
parameter_ranges,
|
||||
reverse_map: _,
|
||||
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
|
||||
|
||||
let input_1_lwe_params = match &operators[input1.i] {
|
||||
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
dbg!(¶meter_ranges.glwe);
|
||||
|
||||
assert_eq!(
|
||||
DEFAUT_DOMAINS.free_glwe,
|
||||
parameter_ranges.glwe[input_1_lwe_params]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_lwe2() {
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
let input2 = graph.add_input(2, Shape::number());
|
||||
|
||||
let cpx_add = LevelledComplexity::ADDITION;
|
||||
let concat =
|
||||
graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
|
||||
|
||||
let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN, 2);
|
||||
|
||||
let graph_params = maximal_unify(graph);
|
||||
|
||||
let range_parametrized::OperationDag {
|
||||
operators,
|
||||
parameter_ranges,
|
||||
reverse_map: _,
|
||||
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
|
||||
|
||||
let input_1_lwe_params = match &operators[input1.i] {
|
||||
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
assert_eq!(
|
||||
DEFAUT_DOMAINS.glwe_pbs_constrained,
|
||||
parameter_ranges.glwe[input_1_lwe_params]
|
||||
);
|
||||
|
||||
let lut1_out_glwe_params = match &operators[lut1.i] {
|
||||
Operator::Lut { extra_data, .. } => extra_data.output_glwe_params_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
assert_eq!(
|
||||
DEFAUT_DOMAINS.glwe_pbs_constrained,
|
||||
parameter_ranges.glwe[lut1_out_glwe_params]
|
||||
);
|
||||
|
||||
let lut1_internal_glwe_params = match &operators[lut1.i] {
|
||||
Operator::Lut { extra_data, .. } => extra_data.internal_lwe_dimension_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
assert_eq!(
|
||||
DEFAUT_DOMAINS.free_glwe,
|
||||
parameter_ranges.glwe[lut1_internal_glwe_params]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,8 @@ fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) {
|
||||
fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
|
||||
for op in &dag.operators {
|
||||
assert_non_empty_inputs(op);
|
||||
assert_inputs_uniform_precisions(op, &dag.out_precisions);
|
||||
assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,10 +70,6 @@ fn assert_valid_variances(dag: &OperationDag) {
|
||||
}
|
||||
|
||||
fn assert_properties_correctness(dag: &OperationDag) {
|
||||
for op in &dag.operators {
|
||||
assert_inputs_uniform_precisions(op, &dag.out_precisions);
|
||||
assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
|
||||
}
|
||||
assert_valid_variances(dag);
|
||||
}
|
||||
|
||||
@@ -89,10 +87,6 @@ fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance])
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OperationDag {
|
||||
pub operators: Vec<Op>,
|
||||
// Collect all operators ouput shape
|
||||
pub out_shapes: Vec<Shape>,
|
||||
// Collect all operators ouput precision
|
||||
pub out_precisions: Vec<Precision>,
|
||||
// Collect all operators ouput variances
|
||||
pub out_variances: Vec<SymbolicVariance>,
|
||||
pub nb_luts: u64,
|
||||
@@ -118,66 +112,6 @@ pub struct VariancesAndBound {
|
||||
pub all_in_lut: Vec<(u64, SymbolicVariance)>,
|
||||
}
|
||||
|
||||
fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape {
|
||||
match op {
|
||||
Op::Input { out_shape, .. } | Op::LevelledOp { out_shape, .. } => out_shape.clone(),
|
||||
Op::Lut { input, .. } => out_shapes[input.i].clone(),
|
||||
Op::Dot {
|
||||
inputs, weights, ..
|
||||
} => {
|
||||
if inputs.is_empty() {
|
||||
return Shape::number();
|
||||
}
|
||||
let input_shape = first(inputs, out_shapes);
|
||||
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
|
||||
match kind {
|
||||
DK::Simple | DK::Tensor => Shape::number(),
|
||||
DK::CompatibleTensor => weights.shape.clone(),
|
||||
DK::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
|
||||
DK::Unsupported { .. } => {
|
||||
let weights_shape = &weights.shape;
|
||||
println!();
|
||||
println!();
|
||||
println!("Error diagnostic on dot operation:");
|
||||
println!("Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>");
|
||||
println!();
|
||||
panic!("Unsupported or invalid dot operation")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn out_shapes(dag: &unparametrized::OperationDag) -> Vec<Shape> {
|
||||
let nb_ops = dag.operators.len();
|
||||
let mut out_shapes = Vec::<Shape>::with_capacity(nb_ops);
|
||||
for op in &dag.operators {
|
||||
let shape = out_shape(op, &mut out_shapes);
|
||||
out_shapes.push(shape);
|
||||
}
|
||||
out_shapes
|
||||
}
|
||||
|
||||
fn out_precision(
|
||||
op: &unparametrized::UnparameterizedOperator,
|
||||
out_precisions: &[Precision],
|
||||
) -> Precision {
|
||||
match op {
|
||||
Op::Input { out_precision, .. } | Op::Lut { out_precision, .. } => *out_precision,
|
||||
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i],
|
||||
}
|
||||
}
|
||||
|
||||
fn out_precisions(dag: &unparametrized::OperationDag) -> Vec<Precision> {
|
||||
let nb_ops = dag.operators.len();
|
||||
let mut out_precisions = Vec::<Precision>::with_capacity(nb_ops);
|
||||
for op in &dag.operators {
|
||||
let precision = out_precision(op, &out_precisions);
|
||||
out_precisions.push(precision);
|
||||
}
|
||||
out_precisions
|
||||
}
|
||||
|
||||
fn out_variance(
|
||||
op: &unparametrized::UnparameterizedOperator,
|
||||
out_shapes: &[Shape],
|
||||
@@ -220,17 +154,15 @@ fn out_variance(
|
||||
DK::Unsupported { .. } => panic!("Unsupported"),
|
||||
}
|
||||
}
|
||||
Op::UnsafeCast { input, .. } => out_variances[input.i],
|
||||
}
|
||||
}
|
||||
|
||||
fn out_variances(
|
||||
dag: &unparametrized::OperationDag,
|
||||
out_shapes: &[Shape],
|
||||
) -> Vec<SymbolicVariance> {
|
||||
fn out_variances(dag: &unparametrized::OperationDag) -> Vec<SymbolicVariance> {
|
||||
let nb_ops = dag.operators.len();
|
||||
let mut out_variances = Vec::with_capacity(nb_ops);
|
||||
for op in &dag.operators {
|
||||
let vf = out_variance(op, out_shapes, &mut out_variances);
|
||||
let vf = out_variance(op, &dag.out_shapes, &mut out_variances);
|
||||
out_variances.push(vf);
|
||||
}
|
||||
out_variances
|
||||
@@ -242,7 +174,7 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
|
||||
for op in &dag.operators {
|
||||
match op {
|
||||
Op::Input { .. } => (),
|
||||
Op::Lut { input, .. } => {
|
||||
Op::Lut { input, .. } | Op::UnsafeCast { input, .. } => {
|
||||
extra_values_to_check[input.i] = false;
|
||||
}
|
||||
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
|
||||
@@ -257,8 +189,6 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
|
||||
|
||||
fn extra_final_variances(
|
||||
dag: &unparametrized::OperationDag,
|
||||
out_shapes: &[Shape],
|
||||
out_precisions: &[Precision],
|
||||
out_variances: &[SymbolicVariance],
|
||||
) -> Vec<(Precision, Shape, SymbolicVariance)> {
|
||||
extra_final_values_to_check(dag)
|
||||
@@ -266,7 +196,11 @@ fn extra_final_variances(
|
||||
.enumerate()
|
||||
.filter_map(|(i, &is_final)| {
|
||||
if is_final {
|
||||
Some((out_precisions[i], out_shapes[i].clone(), out_variances[i]))
|
||||
Some((
|
||||
dag.out_precisions[i],
|
||||
dag.out_shapes[i].clone(),
|
||||
out_variances[i],
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -276,8 +210,6 @@ fn extra_final_variances(
|
||||
|
||||
fn in_luts_variance(
|
||||
dag: &unparametrized::OperationDag,
|
||||
out_shapes: &[Shape],
|
||||
out_precisions: &[Precision],
|
||||
out_variances: &[SymbolicVariance],
|
||||
) -> Vec<(Precision, Shape, SymbolicVariance)> {
|
||||
dag.operators
|
||||
@@ -286,8 +218,8 @@ fn in_luts_variance(
|
||||
.filter_map(|(i, op)| {
|
||||
if let &Op::Lut { input, .. } = op {
|
||||
Some((
|
||||
out_precisions[input.i],
|
||||
out_shapes[i].clone(),
|
||||
dag.out_precisions[input.i],
|
||||
dag.out_shapes[i].clone(),
|
||||
out_variances[input.i],
|
||||
))
|
||||
} else {
|
||||
@@ -315,26 +247,23 @@ fn op_levelled_complexity(
|
||||
}
|
||||
}
|
||||
Op::LevelledOp { complexity, .. } => *complexity,
|
||||
Op::Input { .. } | Op::Lut { .. } => LevelledComplexity::ZERO,
|
||||
Op::Input { .. } | Op::Lut { .. } | Op::UnsafeCast { .. } => LevelledComplexity::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
fn levelled_complexity(
|
||||
dag: &unparametrized::OperationDag,
|
||||
out_shapes: &[Shape],
|
||||
) -> LevelledComplexity {
|
||||
fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity {
|
||||
let mut levelled_complexity = LevelledComplexity::ZERO;
|
||||
for op in &dag.operators {
|
||||
levelled_complexity += op_levelled_complexity(op, out_shapes);
|
||||
levelled_complexity += op_levelled_complexity(op, &dag.out_shapes);
|
||||
}
|
||||
levelled_complexity
|
||||
}
|
||||
|
||||
fn lut_count(dag: &unparametrized::OperationDag, out_shapes: &[Shape]) -> u64 {
|
||||
pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 {
|
||||
let mut count = 0;
|
||||
for (i, op) in dag.operators.iter().enumerate() {
|
||||
if let Op::Lut { .. } = op {
|
||||
count += out_shapes[i].flat_size();
|
||||
count += dag.out_shapes[i].flat_size();
|
||||
}
|
||||
}
|
||||
count
|
||||
@@ -433,10 +362,8 @@ fn constraint_for_one_precision(
|
||||
|
||||
pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
|
||||
assert_dag_correctness(dag);
|
||||
let out_shapes = out_shapes(dag);
|
||||
let out_precisions = out_precisions(dag);
|
||||
let out_variances = out_variances(dag, &out_shapes);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
|
||||
let out_variances = out_variances(dag);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_variances);
|
||||
let coeffs = in_luts_variance
|
||||
.iter()
|
||||
.map(|(_precision, _shape, symbolic_variance)| {
|
||||
@@ -447,25 +374,18 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
|
||||
worst.log2()
|
||||
}
|
||||
|
||||
pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 {
|
||||
lut_count(dag, &out_shapes(dag))
|
||||
}
|
||||
|
||||
pub fn analyze(
|
||||
dag: &unparametrized::OperationDag,
|
||||
noise_config: &NoiseBoundConfig,
|
||||
) -> OperationDag {
|
||||
assert_dag_correctness(dag);
|
||||
let out_shapes = out_shapes(dag);
|
||||
let out_precisions = out_precisions(dag);
|
||||
let out_variances = out_variances(dag, &out_shapes);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
|
||||
let nb_luts = lut_count(dag, &out_shapes);
|
||||
let extra_final_variances =
|
||||
extra_final_variances(dag, &out_shapes, &out_precisions, &out_variances);
|
||||
let levelled_complexity = levelled_complexity(dag, &out_shapes);
|
||||
let out_variances = out_variances(dag);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_variances);
|
||||
let nb_luts = lut_count_from_dag(dag);
|
||||
let extra_final_variances = extra_final_variances(dag, &out_variances);
|
||||
let levelled_complexity = levelled_complexity(dag);
|
||||
let constraints_by_precisions = constraints_by_precisions(
|
||||
&out_precisions,
|
||||
&dag.out_precisions,
|
||||
&extra_final_variances,
|
||||
&in_luts_variance,
|
||||
noise_config,
|
||||
@@ -475,8 +395,6 @@ pub fn analyze(
|
||||
.all(|(_, _, sb)| sb.origin() == VarianceOrigin::Input);
|
||||
let result = OperationDag {
|
||||
operators: dag.operators.clone(),
|
||||
out_shapes,
|
||||
out_precisions,
|
||||
out_variances,
|
||||
nb_luts,
|
||||
levelled_complexity,
|
||||
@@ -712,9 +630,9 @@ mod tests {
|
||||
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
|
||||
|
||||
assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT);
|
||||
assert_eq!(analysis.out_shapes[input1.i], Shape::number());
|
||||
assert_eq!(graph.out_shapes[input1.i], Shape::number());
|
||||
assert_eq!(analysis.levelled_complexity, LevelledComplexity::ZERO);
|
||||
assert_eq!(analysis.out_precisions[input1.i], 1);
|
||||
assert_eq!(graph.out_precisions[input1.i], 1);
|
||||
assert_f64_eq(complexity_cost, 0.0);
|
||||
assert!(analysis.nb_luts == 0);
|
||||
let constraint = analysis.constraint();
|
||||
@@ -735,9 +653,9 @@ mod tests {
|
||||
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
|
||||
|
||||
assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT);
|
||||
assert!(analysis.out_shapes[lut1.i] == Shape::number());
|
||||
assert!(graph.out_shapes[lut1.i] == Shape::number());
|
||||
assert!(analysis.levelled_complexity == LevelledComplexity::ZERO);
|
||||
assert_eq!(analysis.out_precisions[lut1.i], 8);
|
||||
assert_eq!(graph.out_precisions[lut1.i], 8);
|
||||
assert_f64_eq(one_lut_cost, complexity_cost);
|
||||
let constraint = analysis.constraint();
|
||||
assert!(constraint.pareto_output.len() == 1);
|
||||
@@ -765,9 +683,9 @@ mod tests {
|
||||
lut_coeff: 0.0,
|
||||
};
|
||||
assert!(analysis.out_variances[dot.i] == expected_var);
|
||||
assert!(analysis.out_shapes[dot.i] == Shape::number());
|
||||
assert!(graph.out_shapes[dot.i] == Shape::number());
|
||||
assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 2);
|
||||
assert_eq!(analysis.out_precisions[dot.i], 1);
|
||||
assert_eq!(graph.out_precisions[dot.i], 1);
|
||||
let expected_dot_cost = (2 * lwe_dim) as f64;
|
||||
assert_f64_eq(expected_dot_cost, complexity_cost);
|
||||
let constraint = analysis.constraint();
|
||||
@@ -792,7 +710,7 @@ mod tests {
|
||||
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
|
||||
|
||||
assert!(analysis.out_variances[dot.i].origin() == VO::Input);
|
||||
assert_eq!(analysis.out_precisions[dot.i], 3);
|
||||
assert_eq!(graph.out_precisions[dot.i], 3);
|
||||
let expected_square_norm2 = weights.square_norm2() as f64;
|
||||
let actual_square_norm2 = analysis.out_variances[dot.i].input_coeff;
|
||||
// Due to call on log2() to compute manp the result is not exact
|
||||
|
||||
@@ -240,10 +240,10 @@ pub fn optimize(
|
||||
maximum_acceptable_error_probability: config.maximum_acceptable_error_probability,
|
||||
ciphertext_modulus_log,
|
||||
};
|
||||
let dag = analyze::analyze(dag, &noise_config);
|
||||
|
||||
let &min_precision = dag.out_precisions.iter().min().unwrap();
|
||||
|
||||
let dag = analyze::analyze(dag, &noise_config);
|
||||
|
||||
let safe_variance = error::safe_variance_bound_2padbits(
|
||||
min_precision as u64,
|
||||
ciphertext_modulus_log,
|
||||
@@ -357,6 +357,8 @@ pub fn optimize_v0(
|
||||
mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use super::*;
|
||||
use crate::computing_cost::cpu::CpuComplexity;
|
||||
use crate::config;
|
||||
@@ -388,10 +390,12 @@ mod tests {
|
||||
|
||||
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
|
||||
|
||||
fn optimize(
|
||||
dag: &unparametrized::OperationDag,
|
||||
cache: &PersistDecompCache,
|
||||
) -> OptimizationState {
|
||||
static SHARED_CACHES: Lazy<PersistDecompCache> = Lazy::new(|| {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
decomposition::cache(128, processing_unit, None)
|
||||
});
|
||||
|
||||
fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
|
||||
let config = Config {
|
||||
security_level: 128,
|
||||
maximum_acceptable_error_probability: _4_SIGMA,
|
||||
@@ -401,7 +405,7 @@ mod tests {
|
||||
|
||||
let search_space = SearchSpace::default_cpu();
|
||||
|
||||
super::optimize(dag, config, &search_space, cache)
|
||||
super::optimize(dag, config, &search_space, &SHARED_CACHES)
|
||||
}
|
||||
|
||||
struct Times {
|
||||
@@ -442,15 +446,13 @@ mod tests {
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
|
||||
let cache = decomposition::cache(config.security_level, processing_unit, None);
|
||||
|
||||
let _ = optimize_v0(
|
||||
sum_size,
|
||||
precision,
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
&SHARED_CACHES,
|
||||
);
|
||||
// ensure cache is filled
|
||||
|
||||
@@ -461,7 +463,7 @@ mod tests {
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
&SHARED_CACHES,
|
||||
);
|
||||
|
||||
times.dag_time += chrono.elapsed().as_nanos();
|
||||
@@ -472,7 +474,7 @@ mod tests {
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
&SHARED_CACHES,
|
||||
);
|
||||
times.worst_time += chrono.elapsed().as_nanos();
|
||||
assert_eq!(
|
||||
@@ -504,8 +506,6 @@ mod tests {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let security_level = 128;
|
||||
|
||||
let cache = decomposition::cache(security_level, processing_unit, None);
|
||||
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
{
|
||||
let input1 = dag.add_input(precision, Shape::number());
|
||||
@@ -542,14 +542,14 @@ mod tests {
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
|
||||
let state = optimize(&dag, &cache);
|
||||
let state = optimize(&dag);
|
||||
let state_ref = atomic_pattern::optimize_one(
|
||||
1,
|
||||
precision as u64,
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
&SHARED_CACHES,
|
||||
);
|
||||
assert_eq!(
|
||||
state.best_solution.is_some(),
|
||||
@@ -567,7 +567,7 @@ mod tests {
|
||||
assert!(sol.global_p_error <= 1.0);
|
||||
}
|
||||
|
||||
fn no_lut_vs_lut(precision: Precision, cache: &PersistDecompCache) {
|
||||
fn no_lut_vs_lut(precision: Precision) {
|
||||
let mut dag_lut = unparametrized::OperationDag::new();
|
||||
let input1 = dag_lut.add_input(precision, Shape::number());
|
||||
let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision);
|
||||
@@ -575,8 +575,8 @@ mod tests {
|
||||
let mut dag_no_lut = unparametrized::OperationDag::new();
|
||||
let _input2 = dag_no_lut.add_input(precision, Shape::number());
|
||||
|
||||
let state_no_lut = optimize(&dag_no_lut, cache);
|
||||
let state_lut = optimize(&dag_lut, cache);
|
||||
let state_no_lut = optimize(&dag_no_lut);
|
||||
let state_lut = optimize(&dag_lut);
|
||||
assert_eq!(
|
||||
state_no_lut.best_solution.is_some(),
|
||||
state_lut.best_solution.is_some()
|
||||
@@ -592,17 +592,14 @@ mod tests {
|
||||
}
|
||||
#[test]
|
||||
fn test_lut_vs_no_lut() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
for precision in 1..=8 {
|
||||
no_lut_vs_lut(precision, &cache);
|
||||
no_lut_vs_lut(precision);
|
||||
}
|
||||
}
|
||||
|
||||
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
|
||||
precision: Precision,
|
||||
weight: i64,
|
||||
cache: &PersistDecompCache,
|
||||
) {
|
||||
let weight = &Weights::number(weight);
|
||||
|
||||
@@ -622,8 +619,8 @@ mod tests {
|
||||
let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision);
|
||||
}
|
||||
|
||||
let state_1 = optimize(&dag_1, cache);
|
||||
let state_2 = optimize(&dag_2, cache);
|
||||
let state_1 = optimize(&dag_1);
|
||||
let state_2 = optimize(&dag_2);
|
||||
|
||||
if state_1.best_solution.is_none() {
|
||||
assert!(state_2.best_solution.is_none());
|
||||
@@ -636,19 +633,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
for log_weight in 1..=16 {
|
||||
let weight = 1 << log_weight;
|
||||
for precision in 5..=9 {
|
||||
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
|
||||
precision, weight, &cache,
|
||||
);
|
||||
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lut_1_layer_has_better_complexity(precision: Precision, cache: &PersistDecompCache) {
|
||||
fn lut_1_layer_has_better_complexity(precision: Precision) {
|
||||
let dag_1_layer = {
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision, Shape::number());
|
||||
@@ -664,19 +657,17 @@ mod tests {
|
||||
dag
|
||||
};
|
||||
|
||||
let sol_1_layer = optimize(&dag_1_layer, cache).best_solution.unwrap();
|
||||
let sol_2_layer = optimize(&dag_2_layer, cache).best_solution.unwrap();
|
||||
let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap();
|
||||
let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap();
|
||||
assert!(sol_1_layer.complexity < sol_2_layer.complexity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lut_1_layer_is_better() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
// for some reason on 4, 5, 6, the complexity is already minimal
|
||||
// this could be due to pre-defined pareto set
|
||||
for precision in [1, 2, 3, 7, 8] {
|
||||
lut_1_layer_has_better_complexity(precision, &cache);
|
||||
lut_1_layer_has_better_complexity(precision);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -688,10 +679,7 @@ mod tests {
|
||||
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
|
||||
}
|
||||
|
||||
fn assert_multi_precision_dominate_single(
|
||||
weight: i64,
|
||||
cache: &PersistDecompCache,
|
||||
) -> Option<bool> {
|
||||
fn assert_multi_precision_dominate_single(weight: i64) -> Option<bool> {
|
||||
let low_precision = 4u8;
|
||||
let high_precision = 5u8;
|
||||
let mut dag_low = unparametrized::OperationDag::new();
|
||||
@@ -704,12 +692,12 @@ mod tests {
|
||||
circuit(&mut dag_multi, low_precision, weight);
|
||||
circuit(&mut dag_multi, high_precision, 1);
|
||||
}
|
||||
let state_multi = optimize(&dag_multi, cache);
|
||||
let state_multi = optimize(&dag_multi);
|
||||
|
||||
let mut sol_multi = state_multi.best_solution?;
|
||||
|
||||
let state_low = optimize(&dag_low, cache);
|
||||
let state_high = optimize(&dag_high, cache);
|
||||
let state_low = optimize(&dag_low);
|
||||
let state_high = optimize(&dag_high);
|
||||
|
||||
let sol_low = state_low.best_solution.unwrap();
|
||||
let sol_high = state_high.best_solution.unwrap();
|
||||
@@ -728,12 +716,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_multi_precision_dominate_single() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
let mut prev = Some(true); // true -> ... -> true -> false -> ... -> false
|
||||
for log2_weight in 0..29 {
|
||||
let weight = 1 << log2_weight;
|
||||
let current = assert_multi_precision_dominate_single(weight, &cache);
|
||||
let current = assert_multi_precision_dominate_single(weight);
|
||||
#[allow(clippy::match_like_matches_macro)] // less readable
|
||||
let authorized = match (prev, current) {
|
||||
(Some(false), Some(true)) => false,
|
||||
@@ -763,29 +749,22 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_input() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
for precision in [4_u8, 8] {
|
||||
for weight in [1, 3, 27, 243, 729] {
|
||||
for dim in [1, 2, 16, 32] {
|
||||
let _ = check_global_p_error_input(dim, weight, precision, &cache);
|
||||
let _ = check_global_p_error_input(dim, weight, precision);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_global_p_error_input(
|
||||
dim: u64,
|
||||
weight: i64,
|
||||
precision: u8,
|
||||
cache: &PersistDecompCache,
|
||||
) -> f64 {
|
||||
fn check_global_p_error_input(dim: u64, weight: i64, precision: u8) -> f64 {
|
||||
let shape = Shape::vector(dim);
|
||||
let weights = Weights::number(weight);
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision, shape);
|
||||
let _dot1 = dag.add_dot([input1], weights); // this is just several multiply
|
||||
let state = optimize(&dag, cache);
|
||||
let state = optimize(&dag);
|
||||
let sol = state.best_solution.unwrap();
|
||||
let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim);
|
||||
approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim);
|
||||
@@ -794,23 +773,16 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_lut() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
for precision in [4_u8, 8] {
|
||||
for weight in [1, 3, 27, 243, 729] {
|
||||
for depth in [2, 16, 32] {
|
||||
check_global_p_error_lut(depth, weight, precision, &cache);
|
||||
check_global_p_error_lut(depth, weight, precision);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_global_p_error_lut(
|
||||
depth: u64,
|
||||
weight: i64,
|
||||
precision: u8,
|
||||
cache: &PersistDecompCache,
|
||||
) {
|
||||
fn check_global_p_error_lut(depth: u64, weight: i64, precision: u8) {
|
||||
let shape = Shape::number();
|
||||
let weights = Weights::number(weight);
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
@@ -819,7 +791,7 @@ mod tests {
|
||||
let dot = dag.add_dot([last_val], &weights);
|
||||
last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision);
|
||||
}
|
||||
let state = optimize(&dag, cache);
|
||||
let state = optimize(&dag);
|
||||
let sol = state.best_solution.unwrap();
|
||||
// the first lut on input has reduced impact on error probability
|
||||
let lower_nb_dominating_lut = depth - 1;
|
||||
@@ -856,8 +828,6 @@ mod tests {
|
||||
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
|
||||
#[test]
|
||||
fn test_global_p_error_dominating_lut() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
let depth = 128;
|
||||
let weights_low = 1;
|
||||
let weights_high = 1;
|
||||
@@ -870,7 +840,7 @@ mod tests {
|
||||
weights_low,
|
||||
weights_high,
|
||||
);
|
||||
let sol = optimize(&dag, &cache).best_solution.unwrap();
|
||||
let sol = optimize(&dag).best_solution.unwrap();
|
||||
// the 2 first luts and low precision/weight luts have little impact on error probability
|
||||
let nb_dominating_lut = depth - 1;
|
||||
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
|
||||
@@ -885,8 +855,6 @@ mod tests {
|
||||
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
|
||||
#[test]
|
||||
fn test_global_p_error_non_dominating_lut() {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
let cache = decomposition::cache(128, processing_unit, None);
|
||||
let depth = 128;
|
||||
let weights_low = 1024 * 1024 * 3;
|
||||
let weights_high = 1;
|
||||
@@ -899,7 +867,7 @@ mod tests {
|
||||
weights_low,
|
||||
weights_high,
|
||||
);
|
||||
let sol = optimize(&dag, &cache).best_solution.unwrap();
|
||||
let sol = optimize(&dag).best_solution.unwrap();
|
||||
// all intern luts have an impact on error probability almost equaly
|
||||
let nb_dominating_lut = (2 * depth) - 1;
|
||||
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
|
||||
@@ -910,4 +878,89 @@ mod tests {
|
||||
max_relative = 0.05
|
||||
);
|
||||
}
|
||||
|
||||
fn circuit_with_rounded_lut(
|
||||
rounded_precision: Precision,
|
||||
precision: Precision,
|
||||
weight: i64,
|
||||
) -> unparametrized::OperationDag {
|
||||
// circuit with intermediate high precision in levelled op
|
||||
let shape = Shape::number();
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let weight = Weights::number(weight);
|
||||
let val = dag.add_input(precision, &shape);
|
||||
let lut1 =
|
||||
dag.add_expanded_rounded_lut(val, FunctionTable::UNKWOWN, rounded_precision, precision);
|
||||
let dot = dag.add_dot([lut1], &weight);
|
||||
let _lut2 = dag.add_expanded_rounded_lut(
|
||||
dot,
|
||||
FunctionTable::UNKWOWN,
|
||||
rounded_precision,
|
||||
rounded_precision,
|
||||
);
|
||||
dag
|
||||
}
|
||||
|
||||
fn check_global_p_error_rounded_lut(
|
||||
precision: Precision,
|
||||
rounded_precision: Precision,
|
||||
weight: i64,
|
||||
) {
|
||||
let dag_no_rounded = circuit_with_rounded_lut(precision, precision, weight);
|
||||
let dag_rounded = circuit_with_rounded_lut(rounded_precision, precision, weight);
|
||||
let dag_reduced = circuit_with_rounded_lut(rounded_precision, rounded_precision, weight);
|
||||
let best_reduced = optimize(&dag_reduced).best_solution.unwrap();
|
||||
let best_rounded = optimize(&dag_rounded).best_solution.unwrap();
|
||||
let best_no_rounded_complexity = optimize(&dag_no_rounded)
|
||||
.best_solution
|
||||
.map_or(f64::INFINITY, |s| s.complexity);
|
||||
// println!("Slowdown acc {rounded_precision} -> {precision}, {best_rounded.complexity/best_reduced.complexity}");
|
||||
// println!("Speedup tlu {precision} -> {rounded_precision}, {best_no_rounded_complexity/best_rounded.complexity}");
|
||||
if weight == 0 && precision - rounded_precision <= 4
|
||||
|| weight == 16 && precision - rounded_precision <= 3
|
||||
{
|
||||
// linear slowdown with almost no margin
|
||||
assert!(
|
||||
best_rounded.complexity
|
||||
<= best_reduced.complexity
|
||||
* (1.0 + 1.01 * (precision - rounded_precision) as f64)
|
||||
);
|
||||
} else if precision - rounded_precision <= 4 {
|
||||
// linear slowdown with margin
|
||||
assert!(
|
||||
best_rounded.complexity
|
||||
<= best_reduced.complexity
|
||||
* (1.0 + 1.5 * (precision - rounded_precision) as f64)
|
||||
);
|
||||
} else if precision != rounded_precision {
|
||||
// slowdown
|
||||
assert!(best_reduced.complexity < best_rounded.complexity);
|
||||
}
|
||||
// linear speedup
|
||||
assert!(
|
||||
best_rounded.complexity * (precision - rounded_precision) as f64
|
||||
<= best_no_rounded_complexity
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
|
||||
#[test]
|
||||
fn test_global_p_error_rounded_lut() {
|
||||
let precision = 8 as Precision;
|
||||
for rounded_precision in 4..9 {
|
||||
check_global_p_error_rounded_lut(precision, rounded_precision, 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
|
||||
#[test]
|
||||
fn test_global_p_error_increased_accumulator() {
|
||||
let rounded_precision = 8 as Precision;
|
||||
for precision in [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] {
|
||||
for weight in [1, 2, 4, 8, 16, 32, 64, 128] {
|
||||
println!("{precision} {weight}");
|
||||
check_global_p_error_rounded_lut(precision, rounded_precision, weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::dag::operator::{Operator, Precision};
|
||||
use crate::dag::unparametrized::{OperationDag, UnparameterizedOperator};
|
||||
use crate::dag::operator::Precision;
|
||||
use crate::dag::unparametrized::OperationDag;
|
||||
use crate::noise_estimator::p_error::repeat_p_error;
|
||||
use crate::optimization::atomic_pattern::Solution as WpSolution;
|
||||
use crate::optimization::config::{Config, SearchSpace};
|
||||
@@ -18,21 +18,8 @@ pub enum Solution {
|
||||
WopSolution(WopSolution),
|
||||
}
|
||||
|
||||
fn precision_op(op: &UnparameterizedOperator) -> Option<Precision> {
|
||||
match op {
|
||||
Operator::Input { out_precision, .. } | Operator::Lut { out_precision, .. } => {
|
||||
Some(*out_precision)
|
||||
}
|
||||
Operator::Dot { .. } | Operator::LevelledOp { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn max_precision(dag: &OperationDag) -> Precision {
|
||||
dag.operators
|
||||
.iter()
|
||||
.filter_map(precision_op)
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
dag.out_precisions.iter().copied().max().unwrap_or(0)
|
||||
}
|
||||
|
||||
fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {
|
||||
|
||||
Reference in New Issue
Block a user