Avoid quadratic terms that sum to zero. (#3084)

This commit is contained in:
chriseth
2025-07-25 11:25:23 +02:00
committed by GitHub
parent 786642c0bc
commit 993a1ac81b
2 changed files with 123 additions and 15 deletions

View File

@@ -18,6 +18,10 @@ use super::effect::{Assertion, BitDecomposition, BitDecompositionComponent, Effe
use super::range_constraint::RangeConstraint;
use super::symbolic_expression::SymbolicExpression;
/// Terms with more than `MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS` quadratic terms
/// are not analyzed for pairs that sum to zero.
const MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS: usize = 20;
#[derive(Default)]
pub struct ProcessResult<T: RuntimeConstant, V> {
pub effects: Vec<Effect<T, V>>,
@@ -189,6 +193,17 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
(&self.quadratic, self.linear.iter(), &self.constant)
}
/// Computes the degree of a GroupedExpression (as it is contsructed) in the unknown variables.
/// Variables inside runtime constants are ignored.
pub fn degree(&self) -> usize {
self.quadratic
.iter()
.map(|(l, r)| l.degree() + r.degree())
.chain((!self.linear.is_empty()).then_some(1))
.max()
.unwrap_or(0)
}
/// Returns the coefficient of the variable `variable` if this is an affine expression.
/// Panics if the expression is quadratic.
pub fn coefficient_of_variable(&self, var: &V) -> Option<&T> {
@@ -261,6 +276,8 @@ impl<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq> GroupedExpressi
_ => true,
}
});
remove_quadratic_terms_adding_to_zero(&mut self.quadratic);
if to_add.try_to_known().map(|ta| ta.is_known_zero()) != Some(true) {
*self += to_add;
}
@@ -306,6 +323,7 @@ impl<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq> GroupedExpressi
}
})
.collect();
remove_quadratic_terms_adding_to_zero(&mut self.quadratic);
*self += to_add;
}
@@ -862,7 +880,7 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> AddAssign<GroupedExpression<T, V>>
for GroupedExpression<T, V>
{
fn add_assign(&mut self, rhs: Self) {
self.quadratic.extend(rhs.quadratic);
self.quadratic = combine_removing_zeros(std::mem::take(&mut self.quadratic), rhs.quadratic);
for (var, coeff) in rhs.linear {
self.linear
.entry(var.clone())
@@ -874,6 +892,83 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> AddAssign<GroupedExpression<T, V>>
}
}
/// Returns the sum of these quadratic terms while removing terms that
/// cancel each other out.
fn combine_removing_zeros<E: PartialEq>(first: Vec<(E, E)>, mut second: Vec<(E, E)>) -> Vec<(E, E)>
where
for<'a> &'a E: Neg<Output = E>,
{
if first.len() + second.len() > MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS {
// If there are too many terms, we cannot do this efficiently.
return first.into_iter().chain(second).collect();
}
let mut result = first
.into_iter()
.filter(|first| {
// Try to find l1 * r1 inside `second`.
if let Some((j, _)) = second
.iter()
.find_position(|second| quadratic_terms_add_to_zero(first, second))
{
// We found a match, so they cancel each other out, we remove both.
second.remove(j);
false
} else {
true
}
})
.collect_vec();
result.extend(second);
result
}
/// Removes pairs of items from `terms` whose products add to zero.
fn remove_quadratic_terms_adding_to_zero<E: PartialEq>(terms: &mut Vec<(E, E)>)
where
for<'a> &'a E: Neg<Output = E>,
{
if terms.len() > MAX_SUM_SIZE_FOR_QUADRATIC_ANALYSIS {
// If there are too many terms, we cannot do this efficiently.
return;
}
let mut to_remove = HashSet::new();
for ((i, first), (j, second)) in terms.iter().enumerate().tuple_combinations() {
if to_remove.contains(&i) || to_remove.contains(&j) {
// We already removed this term.
continue;
}
if quadratic_terms_add_to_zero(first, second) {
// We found a match, so they cancel each other out, we remove both.
to_remove.insert(i);
to_remove.insert(j);
}
}
if !to_remove.is_empty() {
*terms = terms
.drain(..)
.enumerate()
.filter(|(i, _)| !to_remove.contains(i))
.map(|(_, term)| term)
.collect();
}
}
/// Returns true if `first.0 * first.1 = -second.0 * second.1`,
/// but does not catch all cases.
fn quadratic_terms_add_to_zero<E: PartialEq>(first: &(E, E), second: &(E, E)) -> bool
where
for<'a> &'a E: Neg<Output = E>,
{
let (s0, s1) = second;
// Check if `first.0 * first.1 == -(second.0 * second.1)`, but we can swap left and right
// and we can put the negation either left or right.
let n1 = (&-s0, s1);
let n2 = (s0, &-s1);
[n1, n2].contains(&(&first.0, &first.1)) || [n1, n2].contains(&(&first.1, &first.0))
}
impl<T: RuntimeConstant, V: Clone + Ord + Eq> Sub for &GroupedExpression<T, V> {
type Output = GroupedExpression<T, V>;
@@ -1648,4 +1743,30 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative];
"-t * y"
);
}
#[test]
fn combine_removing_zeros() {
let a = var("x") * var("y") + var("z") * constant(3);
let b = var("t") * var("u") + constant(5) + var("y") * var("x");
assert_eq!(
(a.clone() - b.clone()).to_string(),
"-((t) * (u) - 3 * z + 5)"
);
assert_eq!((b - a).to_string(), "(t) * (u) - 3 * z + 5");
}
#[test]
fn remove_quadratic_zeros_after_substitution() {
let a = var("x") * var("r") + var("z") * constant(3);
let b = var("t") * var("u") + constant(5) + var("y") * var("x");
let mut t = b - a;
// Cannot simplify yet, because the terms are different
assert_eq!(
t.to_string(),
"(t) * (u) + (y) * (x) - (x) * (r) - 3 * z + 5"
);
t.substitute_by_unknown(&"r", &var("y"));
// Now the first term in `a` is equal to the last in `b`.
assert_eq!(t.to_string(), "(t) * (u) - 3 * z + 5");
}
}

View File

@@ -104,7 +104,7 @@ fn is_valid_substitution<T: RuntimeConstant, V: Ord + Clone + Hash + Eq>(
constraint_system: &IndexedConstraintSystem<T, V>,
degree_bound: DegreeBound,
) -> bool {
let replacement_deg = expression_degree(expr);
let replacement_deg = expr.degree();
constraint_system
.constraints_referencing_variables(std::iter::once(var.clone()))
@@ -142,19 +142,6 @@ fn expression_degree_with_virtual_substitution<T: RuntimeConstant, V: Ord + Clon
.unwrap_or(0)
}
/// Computes the degree of a GroupedExpression in the unknown variables.
/// Variables inside runtime constants are ignored.
fn expression_degree<T: RuntimeConstant, V: Ord + Clone>(expr: &GroupedExpression<T, V>) -> usize {
let (quadratic, linear, _) = expr.components();
quadratic
.iter()
.map(|(l, r)| expression_degree(l) + expression_degree(r))
.chain(linear.map(|_| 1))
.max()
.unwrap_or(0)
}
#[cfg(test)]
mod test {
use crate::{