mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: round operation is supported on input dag
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
pub mod operator;
|
||||
pub mod rewrite;
|
||||
pub mod unparametrized;
|
||||
|
||||
@@ -92,6 +92,11 @@ pub enum Operator {
|
||||
input: OperatorIndex,
|
||||
out_precision: Precision, // precision is changed without modifying the input, can be increase or decrease
|
||||
},
|
||||
// Round is expanded to sub-graph on direct representation or fused in lut for Radix and Crt representation.
|
||||
Round {
|
||||
input: OperatorIndex,
|
||||
out_precision: Precision,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
|
||||
2
concrete-optimizer/src/dag/rewrite/mod.rs
Normal file
2
concrete-optimizer/src/dag/rewrite/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod regen;
|
||||
pub mod round;
|
||||
41
concrete-optimizer/src/dag/rewrite/regen.rs
Normal file
41
concrete-optimizer/src/dag/rewrite/regen.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use crate::dag::operator::operator::Operator;
|
||||
use crate::dag::operator::OperatorIndex;
|
||||
use crate::dag::unparametrized::OperationDag;
|
||||
|
||||
fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {
|
||||
let mut op = op.clone();
|
||||
match &mut op {
|
||||
Operator::Input { .. } => (),
|
||||
Operator::Lut { input, .. }
|
||||
| Operator::UnsafeCast { input, .. }
|
||||
| Operator::Round { input, .. } => input.i = old_index_to_new[input.i],
|
||||
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
|
||||
for input in inputs {
|
||||
input.i = old_index_to_new[input.i];
|
||||
}
|
||||
}
|
||||
};
|
||||
op
|
||||
}
|
||||
|
||||
pub(crate) fn regen(
|
||||
dag: &OperationDag,
|
||||
f: &mut dyn FnMut(usize, &Operator, &mut OperationDag) -> Option<OperatorIndex>,
|
||||
) -> OperationDag {
|
||||
let mut regen_dag = OperationDag::new();
|
||||
let mut old_index_to_new = vec![];
|
||||
for (i, op) in dag.operators.iter().enumerate() {
|
||||
let op = reindex_op_inputs(op, &old_index_to_new);
|
||||
let size = regen_dag.operators.len();
|
||||
if let Some(op_i) = f(i, &op, &mut regen_dag) {
|
||||
old_index_to_new.push(op_i.i);
|
||||
} else {
|
||||
assert!(size == regen_dag.operators.len());
|
||||
old_index_to_new.push(regen_dag.len());
|
||||
regen_dag.operators.push(op.clone());
|
||||
regen_dag.out_precisions.push(dag.out_precisions[i]);
|
||||
regen_dag.out_shapes.push(dag.out_shapes[i].clone());
|
||||
}
|
||||
}
|
||||
regen_dag
|
||||
}
|
||||
18
concrete-optimizer/src/dag/rewrite/round.rs
Normal file
18
concrete-optimizer/src/dag/rewrite/round.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use crate::dag::operator::{Operator, OperatorIndex};
|
||||
use crate::dag::unparametrized::OperationDag;
|
||||
|
||||
use super::regen::regen;
|
||||
|
||||
fn regen_round(_: usize, op: &Operator, dag: &mut OperationDag) -> Option<OperatorIndex> {
|
||||
match *op {
|
||||
Operator::Round {
|
||||
input,
|
||||
out_precision,
|
||||
} => Some(dag.add_expanded_round(input, out_precision)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn expand_round(dag: &OperationDag) -> OperationDag {
|
||||
regen(dag, &mut regen_round)
|
||||
}
|
||||
@@ -106,6 +106,19 @@ impl OperationDag {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_round_op(
|
||||
&mut self,
|
||||
input: OperatorIndex,
|
||||
rounded_precision: Precision,
|
||||
) -> OperatorIndex {
|
||||
let in_precision = self.out_precisions[input.i];
|
||||
assert!(rounded_precision <= in_precision);
|
||||
self.add_operator(Operator::Round {
|
||||
input,
|
||||
out_precision: rounded_precision,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
@@ -209,14 +222,25 @@ impl OperationDag {
|
||||
self.add_lut(rounded, table, out_precision)
|
||||
}
|
||||
|
||||
pub fn add_rounded_lut(
|
||||
&mut self,
|
||||
input: OperatorIndex,
|
||||
table: FunctionTable,
|
||||
rounded_precision: Precision,
|
||||
out_precision: Precision,
|
||||
) -> OperatorIndex {
|
||||
let rounded = self.add_round_op(input, rounded_precision);
|
||||
self.add_lut(rounded, table, out_precision)
|
||||
}
|
||||
|
||||
fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape {
|
||||
match op {
|
||||
Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. } => {
|
||||
out_shape.clone()
|
||||
}
|
||||
Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } => {
|
||||
self.out_shapes[input.i].clone()
|
||||
}
|
||||
Operator::Lut { input, .. }
|
||||
| Operator::UnsafeCast { input, .. }
|
||||
| Operator::Round { input, .. } => self.out_shapes[input.i].clone(),
|
||||
Operator::Dot {
|
||||
inputs, weights, ..
|
||||
} => {
|
||||
@@ -247,7 +271,8 @@ impl OperationDag {
|
||||
match op {
|
||||
Operator::Input { out_precision, .. }
|
||||
| Operator::Lut { out_precision, .. }
|
||||
| Operator::UnsafeCast { out_precision, .. } => *out_precision,
|
||||
| Operator::UnsafeCast { out_precision, .. }
|
||||
| Operator::Round { out_precision, .. } => *out_precision,
|
||||
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
|
||||
self.out_precisions[inputs[0].i]
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use super::symbolic_variance::{SymbolicVariance, VarianceOrigin};
|
||||
use crate::dag::operator::{
|
||||
dot_kind, DotKind, LevelledComplexity, OperatorIndex, Precision, Shape,
|
||||
};
|
||||
use crate::dag::rewrite::round::expand_round;
|
||||
use crate::dag::unparametrized;
|
||||
use crate::noise_estimator::error;
|
||||
use crate::noise_estimator::p_error::{combine_errors, repeat_p_error};
|
||||
@@ -59,6 +60,12 @@ fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_no_round(dag: &unparametrized::OperationDag) {
|
||||
for op in &dag.operators {
|
||||
assert!(!matches!(op, Op::Round { .. }));
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_valid_variances(dag: &OperationDag) {
|
||||
for &out_variance in &dag.out_variances {
|
||||
assert!(
|
||||
@@ -155,6 +162,9 @@ fn out_variance(
|
||||
}
|
||||
}
|
||||
Op::UnsafeCast { input, .. } => out_variances[input.i],
|
||||
Op::Round { .. } => {
|
||||
unreachable!("Round should have been either expanded or integrated to a lut")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +184,7 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
|
||||
for op in &dag.operators {
|
||||
match op {
|
||||
Op::Input { .. } => (),
|
||||
Op::Lut { input, .. } | Op::UnsafeCast { input, .. } => {
|
||||
Op::Lut { input, .. } | Op::UnsafeCast { input, .. } | Op::Round { input, .. } => {
|
||||
extra_values_to_check[input.i] = false;
|
||||
}
|
||||
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
|
||||
@@ -248,6 +258,9 @@ fn op_levelled_complexity(
|
||||
}
|
||||
Op::LevelledOp { complexity, .. } => *complexity,
|
||||
Op::Input { .. } | Op::Lut { .. } | Op::UnsafeCast { .. } => LevelledComplexity::ZERO,
|
||||
Op::Round { .. } => {
|
||||
unreachable!("Round should have been either expanded or integrated to a lut")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,6 +392,8 @@ pub fn analyze(
|
||||
noise_config: &NoiseBoundConfig,
|
||||
) -> OperationDag {
|
||||
assert_dag_correctness(dag);
|
||||
let dag = &expand_round(dag);
|
||||
assert_no_round(dag);
|
||||
let out_variances = out_variances(dag);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_variances);
|
||||
let nb_luts = lut_count_from_dag(dag);
|
||||
|
||||
@@ -896,10 +896,9 @@ mod tests {
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let weight = Weights::number(weight);
|
||||
let val = dag.add_input(precision, &shape);
|
||||
let lut1 =
|
||||
dag.add_expanded_rounded_lut(val, FunctionTable::UNKWOWN, rounded_precision, precision);
|
||||
let lut1 = dag.add_rounded_lut(val, FunctionTable::UNKWOWN, rounded_precision, precision);
|
||||
let dot = dag.add_dot([lut1], &weight);
|
||||
let _lut2 = dag.add_expanded_rounded_lut(
|
||||
let _lut2 = dag.add_rounded_lut(
|
||||
dot,
|
||||
FunctionTable::UNKWOWN,
|
||||
rounded_precision,
|
||||
|
||||
Reference in New Issue
Block a user