From 517ab218dcf39459bb73de54b099ba07d67b6d39 Mon Sep 17 00:00:00 2001 From: rudy Date: Tue, 19 Jul 2022 17:13:50 +0200 Subject: [PATCH] feat(dag): support dot broadcast this is to facilitate a test with many multiplication --- .../src/optimization/dag/solo_key/analyze.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index cec907161..948b04dac 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -240,7 +240,7 @@ fn out_variance( let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor => { + DK::Simple | DK::Tensor | DK::Broadcast => { let first_input = inputs[0]; let mut out_variance = SymbolicVariance::ZERO; for (j, &weight) in weights.values.iter().enumerate() { @@ -253,7 +253,7 @@ fn out_variance( } out_variance } - DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"), + DK::CompatibleTensor { .. } => todo!("TODO"), DK::Unsupported { .. } => panic!("Unsupported"), } } @@ -336,8 +336,9 @@ fn op_levelled_complexity( let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor => LevelledComplexity::ADDITION * weights.flat_size(), - DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"), + DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => { + LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size() + } DK::Unsupported { .. } => panic!("Unsupported"), } }