fix(optimizer): incorrect broadcast shape

This commit is contained in:
rudy
2023-09-15 14:47:50 +02:00
committed by Quentin Bourgerie
parent 1c0a70f911
commit 3cd26192bc
4 changed files with 33 additions and 12 deletions

View File

@@ -11,7 +11,7 @@ pub enum DotKind {
// inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same]
// inputs = [[x, y, z], [u, v, w]], weights = [a, b], [x*a + u*b, y*a + v*b, z*a + w*b]
// inputs = [[x, y, z]], weights = [a], [x*a, y*a, z*a]
Broadcast,
Broadcast { shape: Shape },
Unsupported,
}
@@ -25,13 +25,19 @@ pub fn dot_kind<W>(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor<W>
} else if inputs_shape == weights.shape {
DotKind::CompatibleTensor
} else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape {
DotKind::Broadcast
DotKind::Broadcast {
shape: Shape::vector(input_shape.first_dim_size()),
}
} else if weights.shape.is_vector() && weights.shape.flat_size() == nb_inputs {
// Same as simple but with tensor inputs
DotKind::Broadcast
DotKind::Broadcast {
shape: input_shape.clone(),
}
} else if weights.shape.is_number() && nb_inputs == 1 {
// Any input multiply by one number
DotKind::Broadcast
DotKind::Broadcast {
shape: input_shape.clone(),
}
} else {
DotKind::Unsupported
}
@@ -65,7 +71,22 @@ mod tests {
};
assert_eq!(
dot_kind(1, &s2x2, &Weights::vector([1, 2])),
DotKind::Broadcast
DotKind::Broadcast {
shape: Shape::vector(2)
}
);
}
#[test]
fn test_broadcast_scalar_mul() {
let s2x2 = Shape {
dimensions_size: vec![2, 2],
};
assert_eq!(
dot_kind(1, &s2x2, &Weights::number(1)),
DotKind::Broadcast {
shape: s2x2.clone()
}
);
}

View File

@@ -260,7 +260,7 @@ impl OperationDag {
DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => {
Shape::number()
}
DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
DotKind::Broadcast { shape } => shape,
DotKind::Unsupported { .. } => {
let weights_shape = &weights.shape;
println!();

View File

@@ -192,7 +192,7 @@ fn out_variance(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor | DK::Broadcast => {
DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
let inputs_variance = (0..weights.values.len()).map(|j| {
let input = if inputs.len() > 1 {
inputs[j]

View File

@@ -163,7 +163,7 @@ fn out_variance(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor | DK::Broadcast => {
DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
let first_input = inputs[0];
let mut out_variance = SymbolicVariance::ZERO;
for (j, &weight) in weights.values.iter().enumerate() {
@@ -269,7 +269,7 @@ fn op_levelled_complexity(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => {
DK::Simple | DK::Tensor | DK::Broadcast { .. } | DK::CompatibleTensor => {
LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size()
}
DK::Unsupported { .. } => panic!("Unsupported"),
@@ -883,10 +883,10 @@ pub mod tests {
let shape = Shape {
dimensions_size: vec![2, 2],
};
let input1 = graph.add_input(1, shape);
let input1 = graph.add_input(1, &shape);
let weights = &Weights::number(2);
_ = graph.add_dot([input1], weights);
assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
assert!(*graph.out_shapes.last().unwrap() == shape);
let analysis = analyze(&graph);
assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
}
@@ -902,7 +902,7 @@ pub mod tests {
let lut2 = graph.add_lut(input2, FunctionTable::UNKWOWN, 1);
let weights = &Weights::vector([2, 3]);
_ = graph.add_dot([input1, lut2], weights);
assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
assert!(*graph.out_shapes.last().unwrap() == shape);
let analysis = analyze(&graph);
assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
assert_f64_eq(analysis.out_variances.last().unwrap().lut_coeff, 9.0);