feat: multiprecision, allow precision change in lut

This commit is contained in:
rudy
2022-06-01 18:47:08 +02:00
committed by rudy-6-4
parent b3e3a10f22
commit d220eb4009
8 changed files with 82 additions and 53 deletions

View File

@@ -71,12 +71,17 @@ impl OperationDag {
self.0.add_input(out_precision, out_shape).into()
}
fn add_lut(&mut self, input: ffi::OperatorIndex, table: &[u64]) -> ffi::OperatorIndex {
fn add_lut(
&mut self,
input: ffi::OperatorIndex,
table: &[u64],
out_precision: Precision,
) -> ffi::OperatorIndex {
let table = FunctionTable {
values: table.to_owned(),
};
self.0.add_lut(input.into(), table).into()
self.0.add_lut(input.into(), table, out_precision).into()
}
#[allow(clippy::boxed_local)]
@@ -160,7 +165,12 @@ mod ffi {
out_shape: &[u64],
) -> OperatorIndex;
fn add_lut(self: &mut OperationDag, input: OperatorIndex, table: &[u64]) -> OperatorIndex;
fn add_lut(
self: &mut OperationDag,
input: OperatorIndex,
table: &[u64],
out_precision: u8,
) -> OperatorIndex;
fn add_dot(
self: &mut OperationDag,

View File

@@ -620,7 +620,7 @@ namespace concrete_optimizer {
#define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag
struct OperationDag final : public ::rust::Opaque {
::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<const ::std::uint64_t> out_shape) noexcept;
::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table) noexcept;
::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table, ::std::uint8_t out_precision) noexcept;
::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<const ::concrete_optimizer::dag::OperatorIndex> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept;
::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<const ::concrete_optimizer::dag::OperatorIndex> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<const ::std::uint64_t> out_shape, ::rust::Str comment) noexcept;
~OperationDag() = delete;
@@ -698,7 +698,7 @@ extern "C" {
extern "C" {
::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_input(::concrete_optimizer::OperationDag &self, ::std::uint8_t out_precision, ::rust::Slice<const ::std::uint64_t> out_shape) noexcept;
::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_lut(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table) noexcept;
::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_lut(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table, ::std::uint8_t out_precision) noexcept;
::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_dot(::concrete_optimizer::OperationDag &self, ::rust::Slice<const ::concrete_optimizer::dag::OperatorIndex> inputs, ::concrete_optimizer::Weights *weights) noexcept;
@@ -737,8 +737,8 @@ namespace dag {
return concrete_optimizer$cxxbridge1$OperationDag$add_input(*this, out_precision, out_shape);
}
::concrete_optimizer::dag::OperatorIndex OperationDag::add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table) noexcept {
return concrete_optimizer$cxxbridge1$OperationDag$add_lut(*this, input, table);
::concrete_optimizer::dag::OperatorIndex OperationDag::add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table, ::std::uint8_t out_precision) noexcept {
return concrete_optimizer$cxxbridge1$OperationDag$add_lut(*this, input, table, out_precision);
}
::concrete_optimizer::dag::OperatorIndex OperationDag::add_dot(::rust::Slice<const ::concrete_optimizer::dag::OperatorIndex> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept {

View File

@@ -621,7 +621,7 @@ namespace concrete_optimizer {
#define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag
struct OperationDag final : public ::rust::Opaque {
::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<const ::std::uint64_t> out_shape) noexcept;
::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table) noexcept;
::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<const ::std::uint64_t> table, ::std::uint8_t out_precision) noexcept;
::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<const ::concrete_optimizer::dag::OperatorIndex> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept;
::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<const ::concrete_optimizer::dag::OperatorIndex> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<const ::std::uint64_t> out_shape, ::rust::Str comment) noexcept;
~OperationDag() = delete;

View File

@@ -63,7 +63,7 @@ pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraDat
Lut {
input: OperatorIndex,
table: FunctionTable,
//reduced_precision: u64
out_precision: Precision,
extra_data: LutExtraData,
},
Dot {

View File

@@ -34,10 +34,16 @@ impl OperationDag {
})
}
pub fn add_lut(&mut self, input: OperatorIndex, table: FunctionTable) -> OperatorIndex {
pub fn add_lut(
&mut self,
input: OperatorIndex,
table: FunctionTable,
out_precision: Precision,
) -> OperatorIndex {
self.add_operator(Operator::Lut {
input,
table,
out_precision,
extra_data: (),
})
}
@@ -100,14 +106,14 @@ mod tests {
let cpx_add = LevelledComplexity::ADDITION;
let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 1);
let concat =
graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
let dot = graph.add_dot([concat], [1, 2]);
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2);
let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2];
for (expected_i, op_index) in ops_index.iter().enumerate() {
@@ -138,6 +144,7 @@ mod tests {
Operator::Lut {
input: sum1,
table: FunctionTable::UNKWOWN,
out_precision: 1,
extra_data: ()
},
Operator::LevelledOp {
@@ -159,6 +166,7 @@ mod tests {
Operator::Lut {
input: dot,
table: FunctionTable::UNKWOWN,
out_precision: 2,
extra_data: ()
}
]

View File

@@ -116,9 +116,15 @@ fn convert_maximal(op: UnparameterizedOperator) -> OperatorParameterIndexed {
lwe_dimension_index: external_glwe_index,
},
},
Operator::Lut { input, table, .. } => Operator::Lut {
Operator::Lut {
input,
table,
out_precision,
..
} => Operator::Lut {
input,
table,
out_precision,
extra_data: LutParametersIndexed {
input_lwe_dimension_index: external_glwe_index,
ks_decomposition_parameter_index: ks_decomposition_index,
@@ -257,13 +263,13 @@ mod tests {
let cpx_add = LevelledComplexity::ADDITION;
let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 2);
let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
let dot = graph.add_dot([concat], [1, 2]);
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2);
let graph_params = maximal_unify(graph);
@@ -337,7 +343,7 @@ mod tests {
let concat =
graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN);
let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN, 2);
let graph_params = maximal_unify(graph);

View File

@@ -189,11 +189,10 @@ fn out_shapes(dag: &unparametrized::OperationDag) -> Vec<Shape> {
fn out_precision(
op: &unparametrized::UnparameterizedOperator,
out_precisions: &mut [Precision],
out_precisions: &[Precision],
) -> Precision {
match op {
Op::Input { out_precision, .. } => *out_precision,
Op::Lut { input, .. } => out_precisions[input.i],
Op::Input { out_precision, .. } | Op::Lut { out_precision, .. } => *out_precision,
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i],
}
}
@@ -202,7 +201,7 @@ fn out_precisions(dag: &unparametrized::OperationDag) -> Vec<Precision> {
let nb_ops = dag.operators.len();
let mut out_precisions = Vec::<Precision>::with_capacity(nb_ops);
for op in &dag.operators {
let precision = out_precision(op, &mut out_precisions);
let precision = out_precision(op, &out_precisions);
out_precisions.push(precision);
}
out_precisions
@@ -612,7 +611,7 @@ mod tests {
fn test_1_lut() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(8, Shape::number());
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 8);
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
@@ -695,9 +694,9 @@ mod tests {
let input1 = graph.add_input(1, Shape::vector(2));
let weights = &Weights::vector([1, 2]);
let dot1 = graph.add_dot([input1], weights);
let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1);
let dot2 = graph.add_dot([lut1, lut1], weights);
let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN);
let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN, 1);
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
@@ -742,10 +741,10 @@ mod tests {
fn test_lut_dot_mixed_lut() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1);
let weights = &Weights::vector([2, 3]);
let dot1 = graph.add_dot([input1, lut1], weights);
let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1);
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
@@ -768,15 +767,16 @@ mod tests {
#[test]
fn test_multi_precision_input() {
let mut graph = unparametrized::OperationDag::new();
let max_precision = 5_usize;
let max_precision: Precision = 5;
for i in 1..=max_precision {
let _ = graph.add_input(i as u8, Shape::number());
}
let analysis = analyze(&graph);
assert!(analysis.constraints_by_precisions.len() == max_precision);
assert!(analysis.constraints_by_precisions.len() == max_precision as usize);
let mut prev_safe_noise_bound = 0.0;
for (i, ns) in analysis.constraints_by_precisions.iter().enumerate() {
assert_eq!(ns.precision, (max_precision - i) as u8);
let i_prec = i as Precision;
assert_eq!(ns.precision, max_precision - i_prec);
assert_f64_eq(ns.pareto_output[0].input_coeff, 1.0);
assert!(prev_safe_noise_bound < ns.safe_variance_bound);
prev_safe_noise_bound = ns.safe_variance_bound;
@@ -786,16 +786,17 @@ mod tests {
#[test]
fn test_multi_precision_lut() {
let mut graph = unparametrized::OperationDag::new();
let max_precision = 5_usize;
for i in 1..=max_precision {
let input = graph.add_input(i as u8, Shape::number());
let _lut = graph.add_lut(input, FunctionTable::UNKWOWN);
let max_precision: Precision = 5;
for p in 1..=max_precision {
let input = graph.add_input(p, Shape::number());
let _lut = graph.add_lut(input, FunctionTable::UNKWOWN, p);
}
let analysis = analyze(&graph);
assert!(analysis.constraints_by_precisions.len() == max_precision);
assert!(analysis.constraints_by_precisions.len() == max_precision as usize);
let mut prev_safe_noise_bound = 0.0;
for (i, ns) in analysis.constraints_by_precisions.iter().enumerate() {
assert_eq!(ns.precision, (max_precision - i) as u8);
let i_prec = i as Precision;
assert_eq!(ns.precision, max_precision - i_prec);
assert_eq!(ns.pareto_output.len(), 1);
assert_eq!(ns.pareto_in_lut.len(), 1);
assert_f64_eq(ns.pareto_output[0].input_coeff, 0.0);

View File

@@ -1,7 +1,7 @@
use concrete_commons::dispersion::DispersionParameter;
use concrete_commons::numeric::UnsignedInteger;
use crate::dag::operator::LevelledComplexity;
use crate::dag::operator::{LevelledComplexity, Precision};
use crate::dag::unparametrized;
use crate::noise_estimator::error;
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
@@ -300,11 +300,12 @@ pub fn optimize_v0<W: UnsignedInteger>(
let complexity = LevelledComplexity::ADDITION * sum_size;
let comment = "dot";
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, out_shape);
let precision = precision as Precision;
let input1 = dag.add_input(precision, out_shape);
let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
let mut state = optimize::<u64>(
&dag,
security_level,
@@ -456,14 +457,14 @@ mod tests {
}
}
fn v0_parameter_ref_with_dot(precision: u64, weight: u64) {
fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) {
let mut dag = unparametrized::OperationDag::new();
{
let input1 = dag.add_input(precision as u8, Shape::number());
let input1 = dag.add_input(precision, Shape::number());
let dot1 = dag.add_dot([input1], [1]);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
let dot2 = dag.add_dot([lut1], [weight]);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
}
{
let dag2 = analyze::analyze(&dag, &CONFIG);
@@ -488,7 +489,7 @@ mod tests {
let state = optimize(&dag);
let state_ref = atomic_pattern::optimize_one::<u64>(
1,
precision,
precision as u64,
security_level,
weight as f64,
maximum_acceptable_error_probability,
@@ -510,10 +511,10 @@ mod tests {
assert!(sol.assert_same(sol_ref));
}
fn no_lut_vs_lut(precision: u64) {
fn no_lut_vs_lut(precision: Precision) {
let mut dag_lut = unparametrized::OperationDag::new();
let input1 = dag_lut.add_input(precision as u8, Shape::number());
let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN);
let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision);
let mut dag_no_lut = unparametrized::OperationDag::new();
let _input2 = dag_no_lut.add_input(precision as u8, Shape::number());
@@ -540,23 +541,26 @@ mod tests {
}
}
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision: u64, weight: u64) {
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
precision: Precision,
weight: u64,
) {
let weight = &Weights::number(weight);
let mut dag_1 = unparametrized::OperationDag::new();
{
let input1 = dag_1.add_input(precision as u8, Shape::number());
let scaled_input1 = dag_1.add_dot([input1], weight);
let lut1 = dag_1.add_lut(scaled_input1, FunctionTable::UNKWOWN);
let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN);
let lut1 = dag_1.add_lut(scaled_input1, FunctionTable::UNKWOWN, precision);
let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN, precision);
}
let mut dag_2 = unparametrized::OperationDag::new();
{
let input1 = dag_2.add_input(precision as u8, Shape::number());
let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN);
let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN, precision);
let scaled_lut1 = dag_2.add_dot([lut1], weight);
let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN);
let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision);
}
let state_1 = optimize(&dag_1);
@@ -581,12 +585,12 @@ mod tests {
}
}
fn circuit(dag: &mut unparametrized::OperationDag, precision: u8, weight: u64) {
fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) {
let input = dag.add_input(precision, Shape::number());
let dot1 = dag.add_dot([input], [weight]);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
let dot2 = dag.add_dot([lut1], [weight]);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
}
fn assert_multi_precision_dominate_single(weight: u64) -> Option<bool> {