feat: round operation is supported on input dag

This commit is contained in:
rudy
2022-10-18 10:00:05 +02:00
committed by rudy-6-4
parent f19becac21
commit b7fa08ef57
8 changed files with 114 additions and 8 deletions

View File

@@ -1,2 +1,3 @@
pub mod operator;
pub mod rewrite;
pub mod unparametrized;

View File

@@ -92,6 +92,11 @@ pub enum Operator {
input: OperatorIndex,
out_precision: Precision, // precision is changed without modifying the input, can be increase or decrease
},
// Round is expanded to sub-graph on direct representation or fused in lut for Radix and Crt representation.
Round {
input: OperatorIndex,
out_precision: Precision,
},
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]

View File

@@ -0,0 +1,2 @@
pub mod regen;
pub mod round;

View File

@@ -0,0 +1,41 @@
use crate::dag::operator::operator::Operator;
use crate::dag::operator::OperatorIndex;
use crate::dag::unparametrized::OperationDag;
fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {
let mut op = op.clone();
match &mut op {
Operator::Input { .. } => (),
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. } => input.i = old_index_to_new[input.i],
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
for input in inputs {
input.i = old_index_to_new[input.i];
}
}
};
op
}
pub(crate) fn regen(
dag: &OperationDag,
f: &mut dyn FnMut(usize, &Operator, &mut OperationDag) -> Option<OperatorIndex>,
) -> OperationDag {
let mut regen_dag = OperationDag::new();
let mut old_index_to_new = vec![];
for (i, op) in dag.operators.iter().enumerate() {
let op = reindex_op_inputs(op, &old_index_to_new);
let size = regen_dag.operators.len();
if let Some(op_i) = f(i, &op, &mut regen_dag) {
old_index_to_new.push(op_i.i);
} else {
assert!(size == regen_dag.operators.len());
old_index_to_new.push(regen_dag.len());
regen_dag.operators.push(op.clone());
regen_dag.out_precisions.push(dag.out_precisions[i]);
regen_dag.out_shapes.push(dag.out_shapes[i].clone());
}
}
regen_dag
}

View File

@@ -0,0 +1,18 @@
use crate::dag::operator::{Operator, OperatorIndex};
use crate::dag::unparametrized::OperationDag;
use super::regen::regen;
fn regen_round(_: usize, op: &Operator, dag: &mut OperationDag) -> Option<OperatorIndex> {
match *op {
Operator::Round {
input,
out_precision,
} => Some(dag.add_expanded_round(input, out_precision)),
_ => None,
}
}
pub(crate) fn expand_round(dag: &OperationDag) -> OperationDag {
regen(dag, &mut regen_round)
}

View File

@@ -106,6 +106,19 @@ impl OperationDag {
})
}
pub fn add_round_op(
&mut self,
input: OperatorIndex,
rounded_precision: Precision,
) -> OperatorIndex {
let in_precision = self.out_precisions[input.i];
assert!(rounded_precision <= in_precision);
self.add_operator(Operator::Round {
input,
out_precision: rounded_precision,
})
}
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.operators.len()
@@ -209,14 +222,25 @@ impl OperationDag {
self.add_lut(rounded, table, out_precision)
}
pub fn add_rounded_lut(
&mut self,
input: OperatorIndex,
table: FunctionTable,
rounded_precision: Precision,
out_precision: Precision,
) -> OperatorIndex {
let rounded = self.add_round_op(input, rounded_precision);
self.add_lut(rounded, table, out_precision)
}
fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape {
match op {
Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. } => {
out_shape.clone()
}
Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } => {
self.out_shapes[input.i].clone()
}
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. } => self.out_shapes[input.i].clone(),
Operator::Dot {
inputs, weights, ..
} => {
@@ -247,7 +271,8 @@ impl OperationDag {
match op {
Operator::Input { out_precision, .. }
| Operator::Lut { out_precision, .. }
| Operator::UnsafeCast { out_precision, .. } => *out_precision,
| Operator::UnsafeCast { out_precision, .. }
| Operator::Round { out_precision, .. } => *out_precision,
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
self.out_precisions[inputs[0].i]
}

View File

@@ -2,6 +2,7 @@ use super::symbolic_variance::{SymbolicVariance, VarianceOrigin};
use crate::dag::operator::{
dot_kind, DotKind, LevelledComplexity, OperatorIndex, Precision, Shape,
};
use crate::dag::rewrite::round::expand_round;
use crate::dag::unparametrized;
use crate::noise_estimator::error;
use crate::noise_estimator::p_error::{combine_errors, repeat_p_error};
@@ -59,6 +60,12 @@ fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
}
}
fn assert_no_round(dag: &unparametrized::OperationDag) {
for op in &dag.operators {
assert!(!matches!(op, Op::Round { .. }));
}
}
fn assert_valid_variances(dag: &OperationDag) {
for &out_variance in &dag.out_variances {
assert!(
@@ -155,6 +162,9 @@ fn out_variance(
}
}
Op::UnsafeCast { input, .. } => out_variances[input.i],
Op::Round { .. } => {
unreachable!("Round should have been either expanded or integrated to a lut")
}
}
}
@@ -174,7 +184,7 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
for op in &dag.operators {
match op {
Op::Input { .. } => (),
Op::Lut { input, .. } | Op::UnsafeCast { input, .. } => {
Op::Lut { input, .. } | Op::UnsafeCast { input, .. } | Op::Round { input, .. } => {
extra_values_to_check[input.i] = false;
}
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
@@ -248,6 +258,9 @@ fn op_levelled_complexity(
}
Op::LevelledOp { complexity, .. } => *complexity,
Op::Input { .. } | Op::Lut { .. } | Op::UnsafeCast { .. } => LevelledComplexity::ZERO,
Op::Round { .. } => {
unreachable!("Round should have been either expanded or integrated to a lut")
}
}
}
@@ -379,6 +392,8 @@ pub fn analyze(
noise_config: &NoiseBoundConfig,
) -> OperationDag {
assert_dag_correctness(dag);
let dag = &expand_round(dag);
assert_no_round(dag);
let out_variances = out_variances(dag);
let in_luts_variance = in_luts_variance(dag, &out_variances);
let nb_luts = lut_count_from_dag(dag);

View File

@@ -896,10 +896,9 @@ mod tests {
let mut dag = unparametrized::OperationDag::new();
let weight = Weights::number(weight);
let val = dag.add_input(precision, &shape);
let lut1 =
dag.add_expanded_rounded_lut(val, FunctionTable::UNKWOWN, rounded_precision, precision);
let lut1 = dag.add_rounded_lut(val, FunctionTable::UNKWOWN, rounded_precision, precision);
let dot = dag.add_dot([lut1], &weight);
let _lut2 = dag.add_expanded_rounded_lut(
let _lut2 = dag.add_rounded_lut(
dot,
FunctionTable::UNKWOWN,
rounded_precision,