mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-01-09 14:48:16 -05:00
Introduce var transformation. (#3146)
This introduces a single solver layer that does the variable transformations. This way the trait bounds are much easier to state.
This commit is contained in:
@@ -89,6 +89,9 @@ fn solver_based_optimization<T: FieldElement, V: Clone + Ord + Hash + Display>(
|
||||
for (var, value) in assignments.iter() {
|
||||
log::trace!(" {var} = {value}");
|
||||
}
|
||||
// Assert that all substitutions are affine so that the degree
|
||||
// does not increase.
|
||||
assert!(assignments.iter().all(|(_, expr)| expr.is_affine()));
|
||||
constraint_system.apply_substitutions(assignments);
|
||||
Ok(constraint_system)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@ use crate::runtime_constant::{
|
||||
ReferencedSymbols, RuntimeConstant, Substitutable, VarTransformable,
|
||||
};
|
||||
use crate::solver::base::BaseSolver;
|
||||
use crate::solver::boolean_extracted::{BooleanExtractedSolver, Variable};
|
||||
use crate::solver::boolean_extracted::BooleanExtractedSolver;
|
||||
use crate::solver::var_transformation::{VarTransformation, Variable};
|
||||
|
||||
use super::grouped_expression::{Error as QseError, RangeConstraintProvider};
|
||||
|
||||
@@ -19,6 +20,7 @@ mod base;
|
||||
mod boolean_extracted;
|
||||
mod exhaustive_search;
|
||||
mod quadratic_equivalences;
|
||||
mod var_transformation;
|
||||
|
||||
/// Solve a constraint system, i.e. derive assignments for variables in the system.
|
||||
pub fn solve_system<T, V>(
|
||||
@@ -55,11 +57,12 @@ where
|
||||
+ Hash,
|
||||
V: Ord + Clone + Hash + Eq + Display,
|
||||
{
|
||||
let solver = BaseSolver::new(bus_interaction_handler);
|
||||
let mut boolean_extracted_solver = BooleanExtractedSolver::new(solver);
|
||||
boolean_extracted_solver.add_algebraic_constraints(constraint_system.algebraic_constraints);
|
||||
boolean_extracted_solver.add_bus_interactions(constraint_system.bus_interactions);
|
||||
boolean_extracted_solver
|
||||
let mut solver = VarTransformation::new(BooleanExtractedSolver::new(BaseSolver::new(
|
||||
bus_interaction_handler,
|
||||
)));
|
||||
solver.add_algebraic_constraints(constraint_system.algebraic_constraints);
|
||||
solver.add_bus_interactions(constraint_system.bus_interactions);
|
||||
solver
|
||||
}
|
||||
|
||||
pub trait Solver<T: RuntimeConstant, V>: RangeConstraintProvider<T::FieldType, V> + Sized {
|
||||
|
||||
@@ -2,48 +2,14 @@ use crate::boolean_extractor::try_extract_boolean;
|
||||
use crate::constraint_system::BusInteraction;
|
||||
use crate::grouped_expression::{GroupedExpression, RangeConstraintProvider};
|
||||
use crate::range_constraint::RangeConstraint;
|
||||
use crate::runtime_constant::{RuntimeConstant, VarTransformable};
|
||||
use crate::runtime_constant::RuntimeConstant;
|
||||
use crate::solver::var_transformation::Variable;
|
||||
use crate::solver::{Error, Solver, VariableAssignment};
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::fmt::Display;
|
||||
use std::hash::Hash;
|
||||
|
||||
/// We introduce new variables.
|
||||
/// This enum avoids clashes with the original variables.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
|
||||
pub enum Variable<V> {
|
||||
/// A regular variable that also exists in the original system.
|
||||
Original(V),
|
||||
/// A new boolean-constrained variable that was introduced by the solver.
|
||||
Boolean(usize),
|
||||
}
|
||||
|
||||
impl<V> From<V> for Variable<V> {
|
||||
/// Converts a regular variable to a `Variable`.
|
||||
fn from(v: V) -> Self {
|
||||
Variable::Original(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Clone> Variable<V> {
|
||||
pub fn try_to_original(&self) -> Option<V> {
|
||||
match self {
|
||||
Variable::Original(v) => Some(v.clone()),
|
||||
Variable::Boolean(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Display> Display for Variable<V> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Variable::Original(v) => write!(f, "{v}"),
|
||||
Variable::Boolean(i) => write!(f, "boolean_{i}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BooleanVarDispenser {
|
||||
next_boolean_id: usize,
|
||||
}
|
||||
@@ -70,10 +36,9 @@ pub struct BooleanExtractedSolver<T, V, S> {
|
||||
|
||||
impl<T, V, S> BooleanExtractedSolver<T, V, S>
|
||||
where
|
||||
T: RuntimeConstant + VarTransformable<V, Variable<V>>,
|
||||
T::Transformed: RuntimeConstant<FieldType = T::FieldType>,
|
||||
T: RuntimeConstant,
|
||||
V: Clone + Eq,
|
||||
S: Solver<T::Transformed, Variable<V>>,
|
||||
S: Solver<T, Variable<V>>,
|
||||
{
|
||||
pub fn new(solver: S) -> Self {
|
||||
Self {
|
||||
@@ -84,14 +49,14 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V, S> RangeConstraintProvider<T::FieldType, V> for BooleanExtractedSolver<T, V, S>
|
||||
impl<T, V, S> RangeConstraintProvider<T::FieldType, Variable<V>> for BooleanExtractedSolver<T, V, S>
|
||||
where
|
||||
T: RuntimeConstant,
|
||||
S: RangeConstraintProvider<T::FieldType, Variable<V>>,
|
||||
V: Clone,
|
||||
{
|
||||
fn get(&self, var: &V) -> RangeConstraint<T::FieldType> {
|
||||
self.solver.get(&Variable::from(var.clone()))
|
||||
fn get(&self, var: &Variable<V>) -> RangeConstraint<T::FieldType> {
|
||||
self.solver.get(var)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,37 +66,23 @@ impl<T, V, S: Display> Display for BooleanExtractedSolver<T, V, S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V, S> Solver<T, V> for BooleanExtractedSolver<T, V, S>
|
||||
impl<T, V, S> Solver<T, Variable<V>> for BooleanExtractedSolver<T, V, S>
|
||||
where
|
||||
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
|
||||
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
|
||||
+ VarTransformable<Variable<V>, V, Transformed = T>
|
||||
+ Display,
|
||||
T: RuntimeConstant + Display,
|
||||
V: Ord + Clone + Eq + Hash + Display,
|
||||
S: Solver<T::Transformed, Variable<V>>,
|
||||
S: Solver<T, Variable<V>>,
|
||||
{
|
||||
/// Solves the system and ignores all assignments that contain a boolean variable
|
||||
/// (either on the LHS or the RHS).
|
||||
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, V>>, Error> {
|
||||
let assignments = self.solver.solve()?;
|
||||
Ok(assignments
|
||||
.into_iter()
|
||||
.filter_map(|(v, expr)| {
|
||||
let v = v.try_to_original()?;
|
||||
let expr = expr.try_transform_var_type(&mut |v| v.try_to_original())?;
|
||||
Some((v, expr))
|
||||
})
|
||||
.collect())
|
||||
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, Variable<V>>>, Error> {
|
||||
self.solver.solve()
|
||||
}
|
||||
|
||||
fn add_algebraic_constraints(
|
||||
&mut self,
|
||||
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
|
||||
constraints: impl IntoIterator<Item = GroupedExpression<T, Variable<V>>>,
|
||||
) {
|
||||
let mut new_boolean_vars = vec![];
|
||||
self.solver
|
||||
.add_algebraic_constraints(constraints.into_iter().flat_map(|constr| {
|
||||
let constr = constr.transform_var_type(&mut |v| v.clone().into());
|
||||
let extracted = try_extract_boolean(&constr, &mut || {
|
||||
let v = self.boolean_var_dispenser.next_var();
|
||||
new_boolean_vars.push(v.clone());
|
||||
@@ -148,52 +99,37 @@ where
|
||||
|
||||
fn add_bus_interactions(
|
||||
&mut self,
|
||||
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
|
||||
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, Variable<V>>>>,
|
||||
) {
|
||||
self.solver
|
||||
.add_bus_interactions(bus_interactions.into_iter().map(|bus_interaction| {
|
||||
bus_interaction
|
||||
.fields()
|
||||
.map(|expr| {
|
||||
// We cannot extract booleans here because that only works
|
||||
// for "constr = 0".
|
||||
expr.transform_var_type(&mut |v| v.clone().into())
|
||||
})
|
||||
.collect()
|
||||
}))
|
||||
// We cannot extract booleans here because that only works
|
||||
// for "constr = 0".
|
||||
self.solver.add_bus_interactions(bus_interactions)
|
||||
}
|
||||
|
||||
fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint<T::FieldType>) {
|
||||
self.solver
|
||||
.add_range_constraint(&variable.clone().into(), constraint);
|
||||
fn add_range_constraint(
|
||||
&mut self,
|
||||
variable: &Variable<V>,
|
||||
constraint: RangeConstraint<T::FieldType>,
|
||||
) {
|
||||
self.solver.add_range_constraint(variable, constraint);
|
||||
}
|
||||
|
||||
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>) {
|
||||
// We do not add boolean variables because we want constraints
|
||||
// to be removed that only reference variables to be removed and
|
||||
// boolean variables derived from them.
|
||||
let variables_to_keep = variables_to_keep
|
||||
.iter()
|
||||
.map(|v| Variable::from(v.clone()))
|
||||
.collect::<HashSet<_>>();
|
||||
self.solver.retain_variables(&variables_to_keep);
|
||||
fn retain_variables(&mut self, variables_to_keep: &HashSet<Variable<V>>) {
|
||||
self.solver.retain_variables(variables_to_keep);
|
||||
}
|
||||
|
||||
fn range_constraint_for_expression(
|
||||
&self,
|
||||
expr: &GroupedExpression<T, V>,
|
||||
expr: &GroupedExpression<T, Variable<V>>,
|
||||
) -> RangeConstraint<T::FieldType> {
|
||||
let expr = expr.transform_var_type(&mut |v| v.clone().into());
|
||||
self.solver.range_constraint_for_expression(&expr)
|
||||
self.solver.range_constraint_for_expression(expr)
|
||||
}
|
||||
|
||||
fn are_expressions_known_to_be_different(
|
||||
&mut self,
|
||||
a: &GroupedExpression<T, V>,
|
||||
b: &GroupedExpression<T, V>,
|
||||
a: &GroupedExpression<T, Variable<V>>,
|
||||
b: &GroupedExpression<T, Variable<V>>,
|
||||
) -> bool {
|
||||
let a = a.transform_var_type(&mut |v| v.clone().into());
|
||||
let b = b.transform_var_type(&mut |v| v.clone().into());
|
||||
self.solver.are_expressions_known_to_be_different(&a, &b)
|
||||
self.solver.are_expressions_known_to_be_different(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
178
constraint-solver/src/solver/var_transformation.rs
Normal file
178
constraint-solver/src/solver/var_transformation.rs
Normal file
@@ -0,0 +1,178 @@
|
||||
use crate::constraint_system::BusInteraction;
|
||||
use crate::grouped_expression::{GroupedExpression, RangeConstraintProvider};
|
||||
use crate::range_constraint::RangeConstraint;
|
||||
use crate::runtime_constant::{RuntimeConstant, VarTransformable};
|
||||
use crate::solver::{Error, Solver, VariableAssignment};
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::hash::Hash;
|
||||
|
||||
/// We introduce new variables.
|
||||
/// This enum avoids clashes with the original variables.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
|
||||
pub enum Variable<V> {
|
||||
/// A regular variable that also exists in the original system.
|
||||
Original(V),
|
||||
/// A new boolean-constrained variable that was introduced by the solver.
|
||||
Boolean(usize),
|
||||
/// A new variable introduced by the linearizer.
|
||||
Linear(usize),
|
||||
}
|
||||
|
||||
impl<V> From<V> for Variable<V> {
|
||||
/// Converts a regular variable to a `Variable`.
|
||||
fn from(v: V) -> Self {
|
||||
Variable::Original(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Clone> From<&V> for Variable<V> {
|
||||
/// Converts a regular variable to a `Variable`.
|
||||
fn from(v: &V) -> Self {
|
||||
Variable::Original(v.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Clone> Variable<V> {
|
||||
pub fn try_to_original(&self) -> Option<V> {
|
||||
match self {
|
||||
Variable::Original(v) => Some(v.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Display> Display for Variable<V> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Variable::Original(v) => write!(f, "{v}"),
|
||||
Variable::Boolean(i) => write!(f, "bool_{i}"),
|
||||
Variable::Linear(i) => write!(f, "lin_{i}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A solver that transforms variables from one type to another,
|
||||
pub struct VarTransformation<T, V, S> {
|
||||
solver: S,
|
||||
_phantom: std::marker::PhantomData<(T, V)>,
|
||||
}
|
||||
|
||||
impl<T, V, S> VarTransformation<T, V, S>
|
||||
where
|
||||
T: RuntimeConstant + VarTransformable<V, Variable<V>>,
|
||||
T::Transformed: RuntimeConstant<FieldType = T::FieldType>,
|
||||
V: Clone + Eq,
|
||||
S: Solver<T::Transformed, Variable<V>>,
|
||||
{
|
||||
pub fn new(solver: S) -> Self {
|
||||
Self {
|
||||
solver,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V, S> RangeConstraintProvider<T::FieldType, V> for VarTransformation<T, V, S>
|
||||
where
|
||||
T: RuntimeConstant,
|
||||
S: RangeConstraintProvider<T::FieldType, Variable<V>>,
|
||||
V: Clone,
|
||||
{
|
||||
fn get(&self, var: &V) -> RangeConstraint<T::FieldType> {
|
||||
self.solver.get(&Variable::from(var))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V, S: Display> Display for VarTransformation<T, V, S> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.solver)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V, S> Solver<T, V> for VarTransformation<T, V, S>
|
||||
where
|
||||
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
|
||||
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
|
||||
+ VarTransformable<Variable<V>, V, Transformed = T>
|
||||
+ Display,
|
||||
V: Ord + Clone + Eq + Hash + Display,
|
||||
S: Solver<T::Transformed, Variable<V>>,
|
||||
{
|
||||
/// Solves the system and ignores all assignments that contain a new variable
|
||||
/// (either on the LHS or the RHS).
|
||||
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, V>>, Error> {
|
||||
let assignments = self.solver.solve()?;
|
||||
Ok(assignments
|
||||
.into_iter()
|
||||
.filter_map(|(v, expr)| {
|
||||
assert!(expr.is_affine());
|
||||
let v = v.try_to_original()?;
|
||||
let expr = expr.try_transform_var_type(&mut |v| v.try_to_original())?;
|
||||
Some((v, expr))
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn add_algebraic_constraints(
|
||||
&mut self,
|
||||
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
|
||||
) {
|
||||
self.solver
|
||||
.add_algebraic_constraints(constraints.into_iter().map(|c| transform(&c)));
|
||||
}
|
||||
|
||||
fn add_bus_interactions(
|
||||
&mut self,
|
||||
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
|
||||
) {
|
||||
self.solver.add_bus_interactions(
|
||||
bus_interactions
|
||||
.into_iter()
|
||||
.map(|bus_interaction| bus_interaction.fields().map(transform).collect()),
|
||||
)
|
||||
}
|
||||
|
||||
fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint<T::FieldType>) {
|
||||
self.solver
|
||||
.add_range_constraint(&variable.into(), constraint);
|
||||
}
|
||||
|
||||
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>) {
|
||||
// This will cause constraints to be deleted if they
|
||||
// only contain newly added variables.
|
||||
let variables_to_keep = variables_to_keep
|
||||
.iter()
|
||||
.map(From::from)
|
||||
.collect::<HashSet<_>>();
|
||||
self.solver.retain_variables(&variables_to_keep);
|
||||
}
|
||||
|
||||
fn range_constraint_for_expression(
|
||||
&self,
|
||||
expr: &GroupedExpression<T, V>,
|
||||
) -> RangeConstraint<T::FieldType> {
|
||||
self.solver
|
||||
.range_constraint_for_expression(&transform(expr))
|
||||
}
|
||||
|
||||
fn are_expressions_known_to_be_different(
|
||||
&mut self,
|
||||
a: &GroupedExpression<T, V>,
|
||||
b: &GroupedExpression<T, V>,
|
||||
) -> bool {
|
||||
let a = transform(a);
|
||||
let b = transform(b);
|
||||
self.solver.are_expressions_known_to_be_different(&a, &b)
|
||||
}
|
||||
}
|
||||
|
||||
fn transform<T, V: Ord + Clone>(
|
||||
expr: &GroupedExpression<T, V>,
|
||||
) -> GroupedExpression<T::Transformed, Variable<V>>
|
||||
where
|
||||
T: RuntimeConstant + VarTransformable<V, Variable<V>>,
|
||||
{
|
||||
expr.transform_var_type(&mut |v| v.into())
|
||||
}
|
||||
Reference in New Issue
Block a user