From 9e08aa9016dc30ff3288efc27e7c22b3dc1d7ad8 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 11 Aug 2025 17:48:30 +0200 Subject: [PATCH] Introduce var transformation. (#3146) This introduces a single solver layer that does the variable transformations. This way the trait bounds are much easier to state. --- autoprecompiles/src/constraint_optimizer.rs | 3 + constraint-solver/src/solver.rs | 15 +- .../src/solver/boolean_extracted.rs | 126 +++---------- .../src/solver/var_transformation.rs | 178 ++++++++++++++++++ 4 files changed, 221 insertions(+), 101 deletions(-) create mode 100644 constraint-solver/src/solver/var_transformation.rs diff --git a/autoprecompiles/src/constraint_optimizer.rs b/autoprecompiles/src/constraint_optimizer.rs index b8ce9675c..01c5e7692 100644 --- a/autoprecompiles/src/constraint_optimizer.rs +++ b/autoprecompiles/src/constraint_optimizer.rs @@ -89,6 +89,9 @@ fn solver_based_optimization( for (var, value) in assignments.iter() { log::trace!(" {var} = {value}"); } + // Assert that all substitutions are affine so that the degree + // does not increase. + assert!(assignments.iter().all(|(_, expr)| expr.is_affine())); constraint_system.apply_substitutions(assignments); Ok(constraint_system) } diff --git a/constraint-solver/src/solver.rs b/constraint-solver/src/solver.rs index c89b021c4..c97f68dc5 100644 --- a/constraint-solver/src/solver.rs +++ b/constraint-solver/src/solver.rs @@ -7,7 +7,8 @@ use crate::runtime_constant::{ ReferencedSymbols, RuntimeConstant, Substitutable, VarTransformable, }; use crate::solver::base::BaseSolver; -use crate::solver::boolean_extracted::{BooleanExtractedSolver, Variable}; +use crate::solver::boolean_extracted::BooleanExtractedSolver; +use crate::solver::var_transformation::{VarTransformation, Variable}; use super::grouped_expression::{Error as QseError, RangeConstraintProvider}; @@ -19,6 +20,7 @@ mod base; mod boolean_extracted; mod exhaustive_search; mod quadratic_equivalences; +mod var_transformation; /// Solve a constraint system, i.e. derive assignments for variables in the system. pub fn solve_system( @@ -55,11 +57,12 @@ where + Hash, V: Ord + Clone + Hash + Eq + Display, { - let solver = BaseSolver::new(bus_interaction_handler); - let mut boolean_extracted_solver = BooleanExtractedSolver::new(solver); - boolean_extracted_solver.add_algebraic_constraints(constraint_system.algebraic_constraints); - boolean_extracted_solver.add_bus_interactions(constraint_system.bus_interactions); - boolean_extracted_solver + let mut solver = VarTransformation::new(BooleanExtractedSolver::new(BaseSolver::new( + bus_interaction_handler, + ))); + solver.add_algebraic_constraints(constraint_system.algebraic_constraints); + solver.add_bus_interactions(constraint_system.bus_interactions); + solver } pub trait Solver: RangeConstraintProvider + Sized { diff --git a/constraint-solver/src/solver/boolean_extracted.rs b/constraint-solver/src/solver/boolean_extracted.rs index a8534bb05..15e2b222e 100644 --- a/constraint-solver/src/solver/boolean_extracted.rs +++ b/constraint-solver/src/solver/boolean_extracted.rs @@ -2,48 +2,14 @@ use crate::boolean_extractor::try_extract_boolean; use crate::constraint_system::BusInteraction; use crate::grouped_expression::{GroupedExpression, RangeConstraintProvider}; use crate::range_constraint::RangeConstraint; -use crate::runtime_constant::{RuntimeConstant, VarTransformable}; +use crate::runtime_constant::RuntimeConstant; +use crate::solver::var_transformation::Variable; use crate::solver::{Error, Solver, VariableAssignment}; use std::collections::HashSet; -use std::fmt::{Debug, Display}; +use std::fmt::Display; use std::hash::Hash; -/// We introduce new variables. -/// This enum avoids clashes with the original variables. -#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] -pub enum Variable { - /// A regular variable that also exists in the original system. - Original(V), - /// A new boolean-constrained variable that was introduced by the solver. - Boolean(usize), -} - -impl From for Variable { - /// Converts a regular variable to a `Variable`. - fn from(v: V) -> Self { - Variable::Original(v) - } -} - -impl Variable { - pub fn try_to_original(&self) -> Option { - match self { - Variable::Original(v) => Some(v.clone()), - Variable::Boolean(_) => None, - } - } -} - -impl Display for Variable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Variable::Original(v) => write!(f, "{v}"), - Variable::Boolean(i) => write!(f, "boolean_{i}"), - } - } -} - struct BooleanVarDispenser { next_boolean_id: usize, } @@ -70,10 +36,9 @@ pub struct BooleanExtractedSolver { impl BooleanExtractedSolver where - T: RuntimeConstant + VarTransformable>, - T::Transformed: RuntimeConstant, + T: RuntimeConstant, V: Clone + Eq, - S: Solver>, + S: Solver>, { pub fn new(solver: S) -> Self { Self { @@ -84,14 +49,14 @@ where } } -impl RangeConstraintProvider for BooleanExtractedSolver +impl RangeConstraintProvider> for BooleanExtractedSolver where T: RuntimeConstant, S: RangeConstraintProvider>, V: Clone, { - fn get(&self, var: &V) -> RangeConstraint { - self.solver.get(&Variable::from(var.clone())) + fn get(&self, var: &Variable) -> RangeConstraint { + self.solver.get(var) } } @@ -101,37 +66,23 @@ impl Display for BooleanExtractedSolver { } } -impl Solver for BooleanExtractedSolver +impl Solver> for BooleanExtractedSolver where - T: RuntimeConstant + VarTransformable> + Display, - T::Transformed: RuntimeConstant - + VarTransformable, V, Transformed = T> - + Display, + T: RuntimeConstant + Display, V: Ord + Clone + Eq + Hash + Display, - S: Solver>, + S: Solver>, { - /// Solves the system and ignores all assignments that contain a boolean variable - /// (either on the LHS or the RHS). - fn solve(&mut self) -> Result>, Error> { - let assignments = self.solver.solve()?; - Ok(assignments - .into_iter() - .filter_map(|(v, expr)| { - let v = v.try_to_original()?; - let expr = expr.try_transform_var_type(&mut |v| v.try_to_original())?; - Some((v, expr)) - }) - .collect()) + fn solve(&mut self) -> Result>>, Error> { + self.solver.solve() } fn add_algebraic_constraints( &mut self, - constraints: impl IntoIterator>, + constraints: impl IntoIterator>>, ) { let mut new_boolean_vars = vec![]; self.solver .add_algebraic_constraints(constraints.into_iter().flat_map(|constr| { - let constr = constr.transform_var_type(&mut |v| v.clone().into()); let extracted = try_extract_boolean(&constr, &mut || { let v = self.boolean_var_dispenser.next_var(); new_boolean_vars.push(v.clone()); @@ -148,52 +99,37 @@ where fn add_bus_interactions( &mut self, - bus_interactions: impl IntoIterator>>, + bus_interactions: impl IntoIterator>>>, ) { - self.solver - .add_bus_interactions(bus_interactions.into_iter().map(|bus_interaction| { - bus_interaction - .fields() - .map(|expr| { - // We cannot extract booleans here because that only works - // for "constr = 0". - expr.transform_var_type(&mut |v| v.clone().into()) - }) - .collect() - })) + // We cannot extract booleans here because that only works + // for "constr = 0". + self.solver.add_bus_interactions(bus_interactions) } - fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint) { - self.solver - .add_range_constraint(&variable.clone().into(), constraint); + fn add_range_constraint( + &mut self, + variable: &Variable, + constraint: RangeConstraint, + ) { + self.solver.add_range_constraint(variable, constraint); } - fn retain_variables(&mut self, variables_to_keep: &HashSet) { - // We do not add boolean variables because we want constraints - // to be removed that only reference variables to be removed and - // boolean variables derived from them. - let variables_to_keep = variables_to_keep - .iter() - .map(|v| Variable::from(v.clone())) - .collect::>(); - self.solver.retain_variables(&variables_to_keep); + fn retain_variables(&mut self, variables_to_keep: &HashSet>) { + self.solver.retain_variables(variables_to_keep); } fn range_constraint_for_expression( &self, - expr: &GroupedExpression, + expr: &GroupedExpression>, ) -> RangeConstraint { - let expr = expr.transform_var_type(&mut |v| v.clone().into()); - self.solver.range_constraint_for_expression(&expr) + self.solver.range_constraint_for_expression(expr) } fn are_expressions_known_to_be_different( &mut self, - a: &GroupedExpression, - b: &GroupedExpression, + a: &GroupedExpression>, + b: &GroupedExpression>, ) -> bool { - let a = a.transform_var_type(&mut |v| v.clone().into()); - let b = b.transform_var_type(&mut |v| v.clone().into()); - self.solver.are_expressions_known_to_be_different(&a, &b) + self.solver.are_expressions_known_to_be_different(a, b) } } diff --git a/constraint-solver/src/solver/var_transformation.rs b/constraint-solver/src/solver/var_transformation.rs new file mode 100644 index 000000000..298e720cf --- /dev/null +++ b/constraint-solver/src/solver/var_transformation.rs @@ -0,0 +1,178 @@ +use crate::constraint_system::BusInteraction; +use crate::grouped_expression::{GroupedExpression, RangeConstraintProvider}; +use crate::range_constraint::RangeConstraint; +use crate::runtime_constant::{RuntimeConstant, VarTransformable}; +use crate::solver::{Error, Solver, VariableAssignment}; + +use std::collections::HashSet; +use std::fmt::{Debug, Display}; +use std::hash::Hash; + +/// We introduce new variables. +/// This enum avoids clashes with the original variables. +#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] +pub enum Variable { + /// A regular variable that also exists in the original system. + Original(V), + /// A new boolean-constrained variable that was introduced by the solver. + Boolean(usize), + /// A new variable introduced by the linearizer. + Linear(usize), +} + +impl From for Variable { + /// Converts a regular variable to a `Variable`. + fn from(v: V) -> Self { + Variable::Original(v) + } +} + +impl From<&V> for Variable { + /// Converts a regular variable to a `Variable`. + fn from(v: &V) -> Self { + Variable::Original(v.clone()) + } +} + +impl Variable { + pub fn try_to_original(&self) -> Option { + match self { + Variable::Original(v) => Some(v.clone()), + _ => None, + } + } +} + +impl Display for Variable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Variable::Original(v) => write!(f, "{v}"), + Variable::Boolean(i) => write!(f, "bool_{i}"), + Variable::Linear(i) => write!(f, "lin_{i}"), + } + } +} + +/// A solver that transforms variables from one type to another, +pub struct VarTransformation { + solver: S, + _phantom: std::marker::PhantomData<(T, V)>, +} + +impl VarTransformation +where + T: RuntimeConstant + VarTransformable>, + T::Transformed: RuntimeConstant, + V: Clone + Eq, + S: Solver>, +{ + pub fn new(solver: S) -> Self { + Self { + solver, + _phantom: std::marker::PhantomData, + } + } +} + +impl RangeConstraintProvider for VarTransformation +where + T: RuntimeConstant, + S: RangeConstraintProvider>, + V: Clone, +{ + fn get(&self, var: &V) -> RangeConstraint { + self.solver.get(&Variable::from(var)) + } +} + +impl Display for VarTransformation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.solver) + } +} + +impl Solver for VarTransformation +where + T: RuntimeConstant + VarTransformable> + Display, + T::Transformed: RuntimeConstant + + VarTransformable, V, Transformed = T> + + Display, + V: Ord + Clone + Eq + Hash + Display, + S: Solver>, +{ + /// Solves the system and ignores all assignments that contain a new variable + /// (either on the LHS or the RHS). + fn solve(&mut self) -> Result>, Error> { + let assignments = self.solver.solve()?; + Ok(assignments + .into_iter() + .filter_map(|(v, expr)| { + assert!(expr.is_affine()); + let v = v.try_to_original()?; + let expr = expr.try_transform_var_type(&mut |v| v.try_to_original())?; + Some((v, expr)) + }) + .collect()) + } + + fn add_algebraic_constraints( + &mut self, + constraints: impl IntoIterator>, + ) { + self.solver + .add_algebraic_constraints(constraints.into_iter().map(|c| transform(&c))); + } + + fn add_bus_interactions( + &mut self, + bus_interactions: impl IntoIterator>>, + ) { + self.solver.add_bus_interactions( + bus_interactions + .into_iter() + .map(|bus_interaction| bus_interaction.fields().map(transform).collect()), + ) + } + + fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint) { + self.solver + .add_range_constraint(&variable.into(), constraint); + } + + fn retain_variables(&mut self, variables_to_keep: &HashSet) { + // This will cause constraints to be deleted if they + // only contain newly added variables. + let variables_to_keep = variables_to_keep + .iter() + .map(From::from) + .collect::>(); + self.solver.retain_variables(&variables_to_keep); + } + + fn range_constraint_for_expression( + &self, + expr: &GroupedExpression, + ) -> RangeConstraint { + self.solver + .range_constraint_for_expression(&transform(expr)) + } + + fn are_expressions_known_to_be_different( + &mut self, + a: &GroupedExpression, + b: &GroupedExpression, + ) -> bool { + let a = transform(a); + let b = transform(b); + self.solver.are_expressions_known_to_be_different(&a, &b) + } +} + +fn transform( + expr: &GroupedExpression, +) -> GroupedExpression> +where + T: RuntimeConstant + VarTransformable>, +{ + expr.transform_var_type(&mut |v| v.into()) +}