Generic quadratic equiv (#2972)

This commit is contained in:
chriseth
2025-07-02 17:06:16 +02:00
committed by GitHub
parent b96fe393ab
commit e964c9edfb

View File

@@ -3,24 +3,22 @@
use std::{collections::HashSet, fmt::Display, hash::Hash};
use itertools::Itertools;
use powdr_number::FieldElement;
use crate::{
grouped_expression::{QuadraticSymbolicExpression, RangeConstraintProvider},
grouped_expression::{GroupedExpression, RangeConstraintProvider},
range_constraint::RangeConstraint,
runtime_constant::RuntimeConstant,
symbolic_expression::SymbolicExpression,
};
/// Given a list of constraints in the form of quadratic symbolic expressions, tries to determine
/// Given a list of constraints in the form of grouped expressions, tries to determine
/// pairs of equivalent variables.
pub fn find_quadratic_equalities<T: FieldElement, V: Ord + Clone + Hash + Eq + Display>(
constraints: &[QuadraticSymbolicExpression<T, V>],
range_constraints: impl RangeConstraintProvider<T, V>,
pub fn find_quadratic_equalities<T: RuntimeConstant, V: Ord + Clone + Hash + Eq + Display>(
constraints: &[GroupedExpression<T, V>],
range_constraints: impl RangeConstraintProvider<T::FieldType, V>,
) -> Vec<(V, V)> {
let candidates = constraints
.iter()
.filter_map(QuadraticEqualityCandidate::try_from_qse)
.filter_map(QuadraticEqualityCandidate::try_from_grouped_expression)
.filter(|c| c.variables.len() >= 2)
.collect::<Vec<_>>();
candidates
@@ -39,12 +37,12 @@ pub fn find_quadratic_equalities<T: FieldElement, V: Ord + Clone + Hash + Eq + D
/// `expr` or `expr + offset` (see [`QuadraticSymbolicExpression::solve_quadratic`]),
/// then `X` and `Y` must be equal and are returned.
fn process_quadratic_equality_candidate_pair<
T: FieldElement,
T: RuntimeConstant,
V: Ord + Clone + Hash + Eq + Display,
>(
c1: &QuadraticEqualityCandidate<T, V>,
c2: &QuadraticEqualityCandidate<T, V>,
range_constraints: &impl RangeConstraintProvider<T, V>,
range_constraints: &impl RangeConstraintProvider<T::FieldType, V>,
) -> Option<(V, V)> {
if c1.variables.len() != c2.variables.len() || c1.variables.len() < 2 {
return None;
@@ -78,8 +76,8 @@ fn process_quadratic_equality_candidate_pair<
// Now the only remaining check is to see if the affine expressions are the same.
// This could have been the first step, but it is rather expensive, so we do it last.
if c1.expr - QuadraticSymbolicExpression::from_unknown_variable(c1_var.clone())
!= c2.expr - QuadraticSymbolicExpression::from_unknown_variable(c2_var.clone())
if c1.expr - GroupedExpression::from_unknown_variable(c1_var.clone())
!= c2.expr - GroupedExpression::from_unknown_variable(c2_var.clone())
{
return None;
}
@@ -96,19 +94,18 @@ fn process_quadratic_equality_candidate_pair<
}
/// This represents an identity `expr * (expr + offset) = 0`,
/// where `expr` is an affine expression and `offset` is a symbolic expression
/// without unknown variables.
/// where `expr` is an affine expression and `offset` is a runtime constant.
///
/// All unknown variables appearing in `expr` are stored in `variables`.
struct QuadraticEqualityCandidate<T: FieldElement, V: Ord + Clone + Hash + Eq> {
expr: QuadraticSymbolicExpression<T, V>,
offset: SymbolicExpression<T, V>,
struct QuadraticEqualityCandidate<T: RuntimeConstant, V: Ord + Clone + Hash + Eq> {
expr: GroupedExpression<T, V>,
offset: T,
/// All unknown variables in `expr`.
variables: HashSet<V>,
}
impl<T: FieldElement, V: Ord + Clone + Hash + Eq> QuadraticEqualityCandidate<T, V> {
fn try_from_qse(constr: &QuadraticSymbolicExpression<T, V>) -> Option<Self> {
impl<T: RuntimeConstant, V: Ord + Clone + Hash + Eq> QuadraticEqualityCandidate<T, V> {
fn try_from_grouped_expression(constr: &GroupedExpression<T, V>) -> Option<Self> {
let (left, right) = constr.try_as_single_product()?;
if !left.is_affine() || !right.is_affine() {
return None;
@@ -131,18 +128,14 @@ impl<T: FieldElement, V: Ord + Clone + Hash + Eq> QuadraticEqualityCandidate<T,
/// Returns an equivalent candidate that is normalized
/// such that `var` has a coefficient of `1`.
fn normalized_for_var(&self, var: &V) -> Self {
let inverse_coefficient = self
.expr
.coefficient_of_variable(var)
.unwrap()
.field_inverse();
let coefficient = self.expr.coefficient_of_variable(var).unwrap();
// self represents
// `(coeff * var + X) * (coeff * var + X + offset) = 0`
// Dividing by `coeff` twice results in
// `(var + X / coeff) * (var + X / coeff + offset / coeff) = 0`
let offset = &self.offset * &inverse_coefficient;
let expr = self.expr.clone() * inverse_coefficient;
let offset = self.offset.field_div(coefficient);
let expr = self.expr.clone() * coefficient.field_inverse();
Self {
expr,
offset,