Introduce var transformation. (#3146)

This introduces a single solver layer that does the variable
transformations. This way the trait bounds are much easier to state.
This commit is contained in:
chriseth
2025-08-11 17:48:30 +02:00
committed by GitHub
parent f03aa041ff
commit 9e08aa9016
4 changed files with 221 additions and 101 deletions

View File

@@ -89,6 +89,9 @@ fn solver_based_optimization<T: FieldElement, V: Clone + Ord + Hash + Display>(
for (var, value) in assignments.iter() {
log::trace!(" {var} = {value}");
}
// Assert that all substitutions are affine so that the degree
// does not increase.
assert!(assignments.iter().all(|(_, expr)| expr.is_affine()));
constraint_system.apply_substitutions(assignments);
Ok(constraint_system)
}

View File

@@ -7,7 +7,8 @@ use crate::runtime_constant::{
ReferencedSymbols, RuntimeConstant, Substitutable, VarTransformable,
};
use crate::solver::base::BaseSolver;
use crate::solver::boolean_extracted::{BooleanExtractedSolver, Variable};
use crate::solver::boolean_extracted::BooleanExtractedSolver;
use crate::solver::var_transformation::{VarTransformation, Variable};
use super::grouped_expression::{Error as QseError, RangeConstraintProvider};
@@ -19,6 +20,7 @@ mod base;
mod boolean_extracted;
mod exhaustive_search;
mod quadratic_equivalences;
mod var_transformation;
/// Solve a constraint system, i.e. derive assignments for variables in the system.
pub fn solve_system<T, V>(
@@ -55,11 +57,12 @@ where
+ Hash,
V: Ord + Clone + Hash + Eq + Display,
{
let solver = BaseSolver::new(bus_interaction_handler);
let mut boolean_extracted_solver = BooleanExtractedSolver::new(solver);
boolean_extracted_solver.add_algebraic_constraints(constraint_system.algebraic_constraints);
boolean_extracted_solver.add_bus_interactions(constraint_system.bus_interactions);
boolean_extracted_solver
let mut solver = VarTransformation::new(BooleanExtractedSolver::new(BaseSolver::new(
bus_interaction_handler,
)));
solver.add_algebraic_constraints(constraint_system.algebraic_constraints);
solver.add_bus_interactions(constraint_system.bus_interactions);
solver
}
pub trait Solver<T: RuntimeConstant, V>: RangeConstraintProvider<T::FieldType, V> + Sized {

View File

@@ -2,48 +2,14 @@ use crate::boolean_extractor::try_extract_boolean;
use crate::constraint_system::BusInteraction;
use crate::grouped_expression::{GroupedExpression, RangeConstraintProvider};
use crate::range_constraint::RangeConstraint;
use crate::runtime_constant::{RuntimeConstant, VarTransformable};
use crate::runtime_constant::RuntimeConstant;
use crate::solver::var_transformation::Variable;
use crate::solver::{Error, Solver, VariableAssignment};
use std::collections::HashSet;
use std::fmt::{Debug, Display};
use std::fmt::Display;
use std::hash::Hash;
/// We introduce new variables.
/// This enum avoids clashes with the original variables.
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum Variable<V> {
/// A regular variable that also exists in the original system.
Original(V),
/// A new boolean-constrained variable that was introduced by the solver.
Boolean(usize),
}
impl<V> From<V> for Variable<V> {
/// Converts a regular variable to a `Variable`.
fn from(v: V) -> Self {
Variable::Original(v)
}
}
impl<V: Clone> Variable<V> {
pub fn try_to_original(&self) -> Option<V> {
match self {
Variable::Original(v) => Some(v.clone()),
Variable::Boolean(_) => None,
}
}
}
impl<V: Display> Display for Variable<V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Original(v) => write!(f, "{v}"),
Variable::Boolean(i) => write!(f, "boolean_{i}"),
}
}
}
struct BooleanVarDispenser {
next_boolean_id: usize,
}
@@ -70,10 +36,9 @@ pub struct BooleanExtractedSolver<T, V, S> {
impl<T, V, S> BooleanExtractedSolver<T, V, S>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>>,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>,
T: RuntimeConstant,
V: Clone + Eq,
S: Solver<T::Transformed, Variable<V>>,
S: Solver<T, Variable<V>>,
{
pub fn new(solver: S) -> Self {
Self {
@@ -84,14 +49,14 @@ where
}
}
impl<T, V, S> RangeConstraintProvider<T::FieldType, V> for BooleanExtractedSolver<T, V, S>
impl<T, V, S> RangeConstraintProvider<T::FieldType, Variable<V>> for BooleanExtractedSolver<T, V, S>
where
T: RuntimeConstant,
S: RangeConstraintProvider<T::FieldType, Variable<V>>,
V: Clone,
{
fn get(&self, var: &V) -> RangeConstraint<T::FieldType> {
self.solver.get(&Variable::from(var.clone()))
fn get(&self, var: &Variable<V>) -> RangeConstraint<T::FieldType> {
self.solver.get(var)
}
}
@@ -101,37 +66,23 @@ impl<T, V, S: Display> Display for BooleanExtractedSolver<T, V, S> {
}
}
impl<T, V, S> Solver<T, V> for BooleanExtractedSolver<T, V, S>
impl<T, V, S> Solver<T, Variable<V>> for BooleanExtractedSolver<T, V, S>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
+ VarTransformable<Variable<V>, V, Transformed = T>
+ Display,
T: RuntimeConstant + Display,
V: Ord + Clone + Eq + Hash + Display,
S: Solver<T::Transformed, Variable<V>>,
S: Solver<T, Variable<V>>,
{
/// Solves the system and ignores all assignments that contain a boolean variable
/// (either on the LHS or the RHS).
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, V>>, Error> {
let assignments = self.solver.solve()?;
Ok(assignments
.into_iter()
.filter_map(|(v, expr)| {
let v = v.try_to_original()?;
let expr = expr.try_transform_var_type(&mut |v| v.try_to_original())?;
Some((v, expr))
})
.collect())
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, Variable<V>>>, Error> {
self.solver.solve()
}
fn add_algebraic_constraints(
&mut self,
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
constraints: impl IntoIterator<Item = GroupedExpression<T, Variable<V>>>,
) {
let mut new_boolean_vars = vec![];
self.solver
.add_algebraic_constraints(constraints.into_iter().flat_map(|constr| {
let constr = constr.transform_var_type(&mut |v| v.clone().into());
let extracted = try_extract_boolean(&constr, &mut || {
let v = self.boolean_var_dispenser.next_var();
new_boolean_vars.push(v.clone());
@@ -148,52 +99,37 @@ where
fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, Variable<V>>>>,
) {
self.solver
.add_bus_interactions(bus_interactions.into_iter().map(|bus_interaction| {
bus_interaction
.fields()
.map(|expr| {
// We cannot extract booleans here because that only works
// for "constr = 0".
expr.transform_var_type(&mut |v| v.clone().into())
})
.collect()
}))
// We cannot extract booleans here because that only works
// for "constr = 0".
self.solver.add_bus_interactions(bus_interactions)
}
fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint<T::FieldType>) {
self.solver
.add_range_constraint(&variable.clone().into(), constraint);
fn add_range_constraint(
&mut self,
variable: &Variable<V>,
constraint: RangeConstraint<T::FieldType>,
) {
self.solver.add_range_constraint(variable, constraint);
}
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>) {
// We do not add boolean variables because we want constraints
// to be removed that only reference variables to be removed and
// boolean variables derived from them.
let variables_to_keep = variables_to_keep
.iter()
.map(|v| Variable::from(v.clone()))
.collect::<HashSet<_>>();
self.solver.retain_variables(&variables_to_keep);
fn retain_variables(&mut self, variables_to_keep: &HashSet<Variable<V>>) {
self.solver.retain_variables(variables_to_keep);
}
fn range_constraint_for_expression(
&self,
expr: &GroupedExpression<T, V>,
expr: &GroupedExpression<T, Variable<V>>,
) -> RangeConstraint<T::FieldType> {
let expr = expr.transform_var_type(&mut |v| v.clone().into());
self.solver.range_constraint_for_expression(&expr)
self.solver.range_constraint_for_expression(expr)
}
fn are_expressions_known_to_be_different(
&mut self,
a: &GroupedExpression<T, V>,
b: &GroupedExpression<T, V>,
a: &GroupedExpression<T, Variable<V>>,
b: &GroupedExpression<T, Variable<V>>,
) -> bool {
let a = a.transform_var_type(&mut |v| v.clone().into());
let b = b.transform_var_type(&mut |v| v.clone().into());
self.solver.are_expressions_known_to_be_different(&a, &b)
self.solver.are_expressions_known_to_be_different(a, b)
}
}

View File

@@ -0,0 +1,178 @@
use crate::constraint_system::BusInteraction;
use crate::grouped_expression::{GroupedExpression, RangeConstraintProvider};
use crate::range_constraint::RangeConstraint;
use crate::runtime_constant::{RuntimeConstant, VarTransformable};
use crate::solver::{Error, Solver, VariableAssignment};
use std::collections::HashSet;
use std::fmt::{Debug, Display};
use std::hash::Hash;
/// We introduce new variables.
/// This enum avoids clashes with the original variables.
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum Variable<V> {
/// A regular variable that also exists in the original system.
Original(V),
/// A new boolean-constrained variable that was introduced by the solver.
Boolean(usize),
/// A new variable introduced by the linearizer.
Linear(usize),
}
impl<V> From<V> for Variable<V> {
/// Converts a regular variable to a `Variable`.
fn from(v: V) -> Self {
Variable::Original(v)
}
}
impl<V: Clone> From<&V> for Variable<V> {
/// Converts a regular variable to a `Variable`.
fn from(v: &V) -> Self {
Variable::Original(v.clone())
}
}
impl<V: Clone> Variable<V> {
pub fn try_to_original(&self) -> Option<V> {
match self {
Variable::Original(v) => Some(v.clone()),
_ => None,
}
}
}
impl<V: Display> Display for Variable<V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Original(v) => write!(f, "{v}"),
Variable::Boolean(i) => write!(f, "bool_{i}"),
Variable::Linear(i) => write!(f, "lin_{i}"),
}
}
}
/// A solver that transforms variables from one type to another,
pub struct VarTransformation<T, V, S> {
solver: S,
_phantom: std::marker::PhantomData<(T, V)>,
}
impl<T, V, S> VarTransformation<T, V, S>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>>,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>,
V: Clone + Eq,
S: Solver<T::Transformed, Variable<V>>,
{
pub fn new(solver: S) -> Self {
Self {
solver,
_phantom: std::marker::PhantomData,
}
}
}
impl<T, V, S> RangeConstraintProvider<T::FieldType, V> for VarTransformation<T, V, S>
where
T: RuntimeConstant,
S: RangeConstraintProvider<T::FieldType, Variable<V>>,
V: Clone,
{
fn get(&self, var: &V) -> RangeConstraint<T::FieldType> {
self.solver.get(&Variable::from(var))
}
}
impl<T, V, S: Display> Display for VarTransformation<T, V, S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.solver)
}
}
impl<T, V, S> Solver<T, V> for VarTransformation<T, V, S>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
+ VarTransformable<Variable<V>, V, Transformed = T>
+ Display,
V: Ord + Clone + Eq + Hash + Display,
S: Solver<T::Transformed, Variable<V>>,
{
/// Solves the system and ignores all assignments that contain a new variable
/// (either on the LHS or the RHS).
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, V>>, Error> {
let assignments = self.solver.solve()?;
Ok(assignments
.into_iter()
.filter_map(|(v, expr)| {
assert!(expr.is_affine());
let v = v.try_to_original()?;
let expr = expr.try_transform_var_type(&mut |v| v.try_to_original())?;
Some((v, expr))
})
.collect())
}
fn add_algebraic_constraints(
&mut self,
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
) {
self.solver
.add_algebraic_constraints(constraints.into_iter().map(|c| transform(&c)));
}
fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
) {
self.solver.add_bus_interactions(
bus_interactions
.into_iter()
.map(|bus_interaction| bus_interaction.fields().map(transform).collect()),
)
}
fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint<T::FieldType>) {
self.solver
.add_range_constraint(&variable.into(), constraint);
}
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>) {
// This will cause constraints to be deleted if they
// only contain newly added variables.
let variables_to_keep = variables_to_keep
.iter()
.map(From::from)
.collect::<HashSet<_>>();
self.solver.retain_variables(&variables_to_keep);
}
fn range_constraint_for_expression(
&self,
expr: &GroupedExpression<T, V>,
) -> RangeConstraint<T::FieldType> {
self.solver
.range_constraint_for_expression(&transform(expr))
}
fn are_expressions_known_to_be_different(
&mut self,
a: &GroupedExpression<T, V>,
b: &GroupedExpression<T, V>,
) -> bool {
let a = transform(a);
let b = transform(b);
self.solver.are_expressions_known_to_be_different(&a, &b)
}
}
fn transform<T, V: Ord + Clone>(
expr: &GroupedExpression<T, V>,
) -> GroupedExpression<T::Transformed, Variable<V>>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>>,
{
expr.transform_var_type(&mut |v| v.into())
}