feat(optimizer): check dag inputs index during dag correctness check

This commit is contained in:
rudy
2023-03-24 16:50:24 +01:00
committed by rudy-6-4
parent efa866f069
commit 8a672a0c59

View File

@@ -52,11 +52,25 @@ fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) {
}
}
fn assert_inputs_index(op: &unparametrized::UnparameterizedOperator, first_bad_index: usize) {
let valid = match op {
Op::Input { .. } => true,
Op::Lut { input, .. } | Op::UnsafeCast { input, .. } | Op::Round { input, .. } => {
input.i < first_bad_index
}
Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => {
inputs.iter().all(|input| input.i < first_bad_index)
}
};
assert!(valid, "Invalid dag, bad index in op: {op:?}");
}
fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
for op in &dag.operators {
for (i, op) in dag.operators.iter().enumerate() {
assert_non_empty_inputs(op);
assert_inputs_uniform_precisions(op, &dag.out_precisions);
assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
assert_inputs_index(op, i);
}
}