mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(dag): support dot broadcast
this is to facilitate a test with many multiplication
This commit is contained in:
@@ -240,7 +240,7 @@ fn out_variance(
|
||||
let input_shape = first(inputs, out_shapes);
|
||||
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
|
||||
match kind {
|
||||
DK::Simple | DK::Tensor => {
|
||||
DK::Simple | DK::Tensor | DK::Broadcast => {
|
||||
let first_input = inputs[0];
|
||||
let mut out_variance = SymbolicVariance::ZERO;
|
||||
for (j, &weight) in weights.values.iter().enumerate() {
|
||||
@@ -253,7 +253,7 @@ fn out_variance(
|
||||
}
|
||||
out_variance
|
||||
}
|
||||
DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
|
||||
DK::CompatibleTensor { .. } => todo!("TODO"),
|
||||
DK::Unsupported { .. } => panic!("Unsupported"),
|
||||
}
|
||||
}
|
||||
@@ -336,8 +336,9 @@ fn op_levelled_complexity(
|
||||
let input_shape = first(inputs, out_shapes);
|
||||
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
|
||||
match kind {
|
||||
DK::Simple | DK::Tensor => LevelledComplexity::ADDITION * weights.flat_size(),
|
||||
DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
|
||||
DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => {
|
||||
LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size()
|
||||
}
|
||||
DK::Unsupported { .. } => panic!("Unsupported"),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user