diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index cec907161..948b04dac 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -240,7 +240,7 @@ fn out_variance( let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor => { + DK::Simple | DK::Tensor | DK::Broadcast => { let first_input = inputs[0]; let mut out_variance = SymbolicVariance::ZERO; for (j, &weight) in weights.values.iter().enumerate() { @@ -253,7 +253,7 @@ fn out_variance( } out_variance } - DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"), + DK::CompatibleTensor { .. } => todo!("TODO"), DK::Unsupported { .. } => panic!("Unsupported"), } } @@ -336,8 +336,9 @@ fn op_levelled_complexity( let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor => LevelledComplexity::ADDITION * weights.flat_size(), - DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"), + DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => { + LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size() + } DK::Unsupported { .. } => panic!("Unsupported"), } }