mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-01-09 14:48:16 -05:00
Avoid quadratic terms that sum to zero. (#3084)
This commit is contained in:
@@ -18,6 +18,10 @@ use super::effect::{Assertion, BitDecomposition, BitDecompositionComponent, Effe
|
||||
use super::range_constraint::RangeConstraint;
|
||||
use super::symbolic_expression::SymbolicExpression;
|
||||
|
||||
/// Terms with more than `MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS` quadratic terms
|
||||
/// are not analyzed for pairs that sum to zero.
|
||||
const MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS: usize = 20;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ProcessResult<T: RuntimeConstant, V> {
|
||||
pub effects: Vec<Effect<T, V>>,
|
||||
@@ -189,6 +193,17 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
|
||||
(&self.quadratic, self.linear.iter(), &self.constant)
|
||||
}
|
||||
|
||||
/// Computes the degree of a GroupedExpression (as it is contsructed) in the unknown variables.
|
||||
/// Variables inside runtime constants are ignored.
|
||||
pub fn degree(&self) -> usize {
|
||||
self.quadratic
|
||||
.iter()
|
||||
.map(|(l, r)| l.degree() + r.degree())
|
||||
.chain((!self.linear.is_empty()).then_some(1))
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Returns the coefficient of the variable `variable` if this is an affine expression.
|
||||
/// Panics if the expression is quadratic.
|
||||
pub fn coefficient_of_variable(&self, var: &V) -> Option<&T> {
|
||||
@@ -261,6 +276,8 @@ impl<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq> GroupedExpressi
|
||||
_ => true,
|
||||
}
|
||||
});
|
||||
remove_quadratic_terms_adding_to_zero(&mut self.quadratic);
|
||||
|
||||
if to_add.try_to_known().map(|ta| ta.is_known_zero()) != Some(true) {
|
||||
*self += to_add;
|
||||
}
|
||||
@@ -306,6 +323,7 @@ impl<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq> GroupedExpressi
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
remove_quadratic_terms_adding_to_zero(&mut self.quadratic);
|
||||
|
||||
*self += to_add;
|
||||
}
|
||||
@@ -862,7 +880,7 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> AddAssign<GroupedExpression<T, V>>
|
||||
for GroupedExpression<T, V>
|
||||
{
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
self.quadratic.extend(rhs.quadratic);
|
||||
self.quadratic = combine_removing_zeros(std::mem::take(&mut self.quadratic), rhs.quadratic);
|
||||
for (var, coeff) in rhs.linear {
|
||||
self.linear
|
||||
.entry(var.clone())
|
||||
@@ -874,6 +892,83 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> AddAssign<GroupedExpression<T, V>>
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sum of these quadratic terms while removing terms that
|
||||
/// cancel each other out.
|
||||
fn combine_removing_zeros<E: PartialEq>(first: Vec<(E, E)>, mut second: Vec<(E, E)>) -> Vec<(E, E)>
|
||||
where
|
||||
for<'a> &'a E: Neg<Output = E>,
|
||||
{
|
||||
if first.len() + second.len() > MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS {
|
||||
// If there are too many terms, we cannot do this efficiently.
|
||||
return first.into_iter().chain(second).collect();
|
||||
}
|
||||
|
||||
let mut result = first
|
||||
.into_iter()
|
||||
.filter(|first| {
|
||||
// Try to find l1 * r1 inside `second`.
|
||||
if let Some((j, _)) = second
|
||||
.iter()
|
||||
.find_position(|second| quadratic_terms_add_to_zero(first, second))
|
||||
{
|
||||
// We found a match, so they cancel each other out, we remove both.
|
||||
second.remove(j);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
result.extend(second);
|
||||
result
|
||||
}
|
||||
|
||||
/// Removes pairs of items from `terms` whose products add to zero.
|
||||
fn remove_quadratic_terms_adding_to_zero<E: PartialEq>(terms: &mut Vec<(E, E)>)
|
||||
where
|
||||
for<'a> &'a E: Neg<Output = E>,
|
||||
{
|
||||
if terms.len() > MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS {
|
||||
// If there are too many terms, we cannot do this efficiently.
|
||||
return;
|
||||
}
|
||||
|
||||
let mut to_remove = HashSet::new();
|
||||
for ((i, first), (j, second)) in terms.iter().enumerate().tuple_combinations() {
|
||||
if to_remove.contains(&i) || to_remove.contains(&j) {
|
||||
// We already removed this term.
|
||||
continue;
|
||||
}
|
||||
if quadratic_terms_add_to_zero(first, second) {
|
||||
// We found a match, so they cancel each other out, we remove both.
|
||||
to_remove.insert(i);
|
||||
to_remove.insert(j);
|
||||
}
|
||||
}
|
||||
if !to_remove.is_empty() {
|
||||
*terms = terms
|
||||
.drain(..)
|
||||
.enumerate()
|
||||
.filter(|(i, _)| !to_remove.contains(i))
|
||||
.map(|(_, term)| term)
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if `first.0 * first.1 = -second.0 * second.1`,
|
||||
/// but does not catch all cases.
|
||||
fn quadratic_terms_add_to_zero<E: PartialEq>(first: &(E, E), second: &(E, E)) -> bool
|
||||
where
|
||||
for<'a> &'a E: Neg<Output = E>,
|
||||
{
|
||||
let (s0, s1) = second;
|
||||
// Check if `first.0 * first.1 == -(second.0 * second.1)`, but we can swap left and right
|
||||
// and we can put the negation either left or right.
|
||||
let n1 = (&-s0, s1);
|
||||
let n2 = (s0, &-s1);
|
||||
[n1, n2].contains(&(&first.0, &first.1)) || [n1, n2].contains(&(&first.1, &first.0))
|
||||
}
|
||||
|
||||
impl<T: RuntimeConstant, V: Clone + Ord + Eq> Sub for &GroupedExpression<T, V> {
|
||||
type Output = GroupedExpression<T, V>;
|
||||
|
||||
@@ -1648,4 +1743,30 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative];
|
||||
"-t * y"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combine_removing_zeros() {
|
||||
let a = var("x") * var("y") + var("z") * constant(3);
|
||||
let b = var("t") * var("u") + constant(5) + var("y") * var("x");
|
||||
assert_eq!(
|
||||
(a.clone() - b.clone()).to_string(),
|
||||
"-((t) * (u) - 3 * z + 5)"
|
||||
);
|
||||
assert_eq!((b - a).to_string(), "(t) * (u) - 3 * z + 5");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_quadratic_zeros_after_substitution() {
|
||||
let a = var("x") * var("r") + var("z") * constant(3);
|
||||
let b = var("t") * var("u") + constant(5) + var("y") * var("x");
|
||||
let mut t = b - a;
|
||||
// Cannot simplify yet, because the terms are different
|
||||
assert_eq!(
|
||||
t.to_string(),
|
||||
"(t) * (u) + (y) * (x) - (x) * (r) - 3 * z + 5"
|
||||
);
|
||||
t.substitute_by_unknown(&"r", &var("y"));
|
||||
// Now the first term in `a` is equal to the last in `b`.
|
||||
assert_eq!(t.to_string(), "(t) * (u) - 3 * z + 5");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ fn is_valid_substitution<T: RuntimeConstant, V: Ord + Clone + Hash + Eq>(
|
||||
constraint_system: &IndexedConstraintSystem<T, V>,
|
||||
degree_bound: DegreeBound,
|
||||
) -> bool {
|
||||
let replacement_deg = expression_degree(expr);
|
||||
let replacement_deg = expr.degree();
|
||||
|
||||
constraint_system
|
||||
.constraints_referencing_variables(std::iter::once(var.clone()))
|
||||
@@ -142,19 +142,6 @@ fn expression_degree_with_virtual_substitution<T: RuntimeConstant, V: Ord + Clon
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Computes the degree of a GroupedExpression in the unknown variables.
|
||||
/// Variables inside runtime constants are ignored.
|
||||
fn expression_degree<T: RuntimeConstant, V: Ord + Clone>(expr: &GroupedExpression<T, V>) -> usize {
|
||||
let (quadratic, linear, _) = expr.components();
|
||||
|
||||
quadratic
|
||||
.iter()
|
||||
.map(|(l, r)| expression_degree(l) + expression_degree(r))
|
||||
.chain(linear.map(|_| 1))
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{
|
||||
|
||||
Reference in New Issue
Block a user