mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
fix(optimizer): incorrect broadcast shape
This commit is contained in:
@@ -11,7 +11,7 @@ pub enum DotKind {
|
||||
// inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same]
|
||||
// inputs = [[x, y, z], [u, v, w]], weights = [a, b], [x*a + u*b, y*a + v*b, z*c + w*c]
|
||||
// inputs = [[x, y, z]], weights = [a], [x*a, y*a, z*a]
|
||||
Broadcast,
|
||||
Broadcast { shape: Shape },
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
@@ -25,13 +25,19 @@ pub fn dot_kind<W>(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor<W>
|
||||
} else if inputs_shape == weights.shape {
|
||||
DotKind::CompatibleTensor
|
||||
} else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape {
|
||||
DotKind::Broadcast
|
||||
DotKind::Broadcast {
|
||||
shape: Shape::vector(input_shape.first_dim_size()),
|
||||
}
|
||||
} else if weights.shape.is_vector() && weights.shape.flat_size() == nb_inputs {
|
||||
// Same as simple but with tensor inputs
|
||||
DotKind::Broadcast
|
||||
DotKind::Broadcast {
|
||||
shape: input_shape.clone(),
|
||||
}
|
||||
} else if weights.shape.is_number() && nb_inputs == 1 {
|
||||
// Any input multiply by one number
|
||||
DotKind::Broadcast
|
||||
DotKind::Broadcast {
|
||||
shape: input_shape.clone(),
|
||||
}
|
||||
} else {
|
||||
DotKind::Unsupported
|
||||
}
|
||||
@@ -65,7 +71,22 @@ mod tests {
|
||||
};
|
||||
assert_eq!(
|
||||
dot_kind(1, &s2x2, &Weights::vector([1, 2])),
|
||||
DotKind::Broadcast
|
||||
DotKind::Broadcast {
|
||||
shape: Shape::vector(2)
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_scalar_mul() {
|
||||
let s2x2 = Shape {
|
||||
dimensions_size: vec![2, 2],
|
||||
};
|
||||
assert_eq!(
|
||||
dot_kind(1, &s2x2, &Weights::number(1)),
|
||||
DotKind::Broadcast {
|
||||
shape: s2x2.clone()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -260,7 +260,7 @@ impl OperationDag {
|
||||
DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => {
|
||||
Shape::number()
|
||||
}
|
||||
DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
|
||||
DotKind::Broadcast { shape } => shape,
|
||||
DotKind::Unsupported { .. } => {
|
||||
let weights_shape = &weights.shape;
|
||||
println!();
|
||||
|
||||
@@ -192,7 +192,7 @@ fn out_variance(
|
||||
let input_shape = first(inputs, out_shapes);
|
||||
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
|
||||
match kind {
|
||||
DK::Simple | DK::Tensor | DK::Broadcast => {
|
||||
DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
|
||||
let inputs_variance = (0..weights.values.len()).map(|j| {
|
||||
let input = if inputs.len() > 1 {
|
||||
inputs[j]
|
||||
|
||||
@@ -163,7 +163,7 @@ fn out_variance(
|
||||
let input_shape = first(inputs, out_shapes);
|
||||
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
|
||||
match kind {
|
||||
DK::Simple | DK::Tensor | DK::Broadcast => {
|
||||
DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
|
||||
let first_input = inputs[0];
|
||||
let mut out_variance = SymbolicVariance::ZERO;
|
||||
for (j, &weight) in weights.values.iter().enumerate() {
|
||||
@@ -269,7 +269,7 @@ fn op_levelled_complexity(
|
||||
let input_shape = first(inputs, out_shapes);
|
||||
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
|
||||
match kind {
|
||||
DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => {
|
||||
DK::Simple | DK::Tensor | DK::Broadcast { .. } | DK::CompatibleTensor => {
|
||||
LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size()
|
||||
}
|
||||
DK::Unsupported { .. } => panic!("Unsupported"),
|
||||
@@ -883,10 +883,10 @@ pub mod tests {
|
||||
let shape = Shape {
|
||||
dimensions_size: vec![2, 2],
|
||||
};
|
||||
let input1 = graph.add_input(1, shape);
|
||||
let input1 = graph.add_input(1, &shape);
|
||||
let weights = &Weights::number(2);
|
||||
_ = graph.add_dot([input1], weights);
|
||||
assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
|
||||
assert!(*graph.out_shapes.last().unwrap() == shape);
|
||||
let analysis = analyze(&graph);
|
||||
assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
|
||||
}
|
||||
@@ -902,7 +902,7 @@ pub mod tests {
|
||||
let lut2 = graph.add_lut(input2, FunctionTable::UNKWOWN, 1);
|
||||
let weights = &Weights::vector([2, 3]);
|
||||
_ = graph.add_dot([input1, lut2], weights);
|
||||
assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
|
||||
assert!(*graph.out_shapes.last().unwrap() == shape);
|
||||
let analysis = analyze(&graph);
|
||||
assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
|
||||
assert_f64_eq(analysis.out_variances.last().unwrap().lut_coeff, 9.0);
|
||||
|
||||
Reference in New Issue
Block a user