feat(dag): support dot broadcast

this is to facilitate a test with many multiplication
This commit is contained in:
rudy
2022-07-19 17:13:50 +02:00
committed by rudy-6-4
parent 4fffc26bbc
commit 517ab218dc

View File

@@ -240,7 +240,7 @@ fn out_variance(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor => {
DK::Simple | DK::Tensor | DK::Broadcast => {
let first_input = inputs[0];
let mut out_variance = SymbolicVariance::ZERO;
for (j, &weight) in weights.values.iter().enumerate() {
@@ -253,7 +253,7 @@ fn out_variance(
}
out_variance
}
DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
DK::CompatibleTensor { .. } => todo!("TODO"),
DK::Unsupported { .. } => panic!("Unsupported"),
}
}
@@ -336,8 +336,9 @@ fn op_levelled_complexity(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor => LevelledComplexity::ADDITION * weights.flat_size(),
DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => {
LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size()
}
DK::Unsupported { .. } => panic!("Unsupported"),
}
}