From 8a672a0c594ef9002db8e2bf0f6464591cf714fa Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 24 Mar 2023 16:50:24 +0100 Subject: [PATCH] feat(optimizer): check dag inputs index during dag correctness check --- .../src/optimization/dag/solo_key/analyze.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 2e86858e5..02f8e14c9 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -52,11 +52,25 @@ fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) { } } +fn assert_inputs_index(op: &unparametrized::UnparameterizedOperator, first_bad_index: usize) { + let valid = match op { + Op::Input { .. } => true, + Op::Lut { input, .. } | Op::UnsafeCast { input, .. } | Op::Round { input, .. } => { + input.i < first_bad_index + } + Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => { + inputs.iter().all(|input| input.i < first_bad_index) + } + }; + assert!(valid, "Invalid dag, bad index in op: {op:?}"); +} + fn assert_dag_correctness(dag: &unparametrized::OperationDag) { - for op in &dag.operators { + for (i, op) in dag.operators.iter().enumerate() { assert_non_empty_inputs(op); assert_inputs_uniform_precisions(op, &dag.out_precisions); assert_dot_uniform_inputs_shape(op, &dag.out_shapes); + assert_inputs_index(op, i); } }