Boolean extracted solver (#3094)

Extracts boolean variables as alternative constraints for all
constraints in the solver.

The exact implementation adds a layer of the Solver trait on top of an
existing solver where we try to extract a boolean variable for all
constraints that are added to the solver.

The interface is also simplified a little in that we always start out
with an "empty" solver and use the `add_algebraic_constraints` and
`add_bus_interaction` functions to fill it with the system.
This commit is contained in:
chriseth
2025-08-01 10:21:20 +02:00
committed by GitHub
parent 806a858cac
commit 4544897d58
5 changed files with 337 additions and 77 deletions

View File

@@ -415,28 +415,31 @@ impl<T: RuntimeConstant + VarTransformable<V1, V2>, V1: Ord + Clone, V2: Ord + C
{
type Transformed = GroupedExpression<T::Transformed, V2>;
fn transform_var_type(&self, var_transform: &mut impl FnMut(&V1) -> V2) -> Self::Transformed {
GroupedExpression {
fn try_transform_var_type(
&self,
var_transform: &mut impl FnMut(&V1) -> Option<V2>,
) -> Option<Self::Transformed> {
Some(GroupedExpression {
quadratic: self
.quadratic
.iter()
.map(|(l, r)| {
(
l.transform_var_type(var_transform),
r.transform_var_type(var_transform),
)
Some((
l.try_transform_var_type(var_transform)?,
r.try_transform_var_type(var_transform)?,
))
})
.collect(),
.collect::<Option<Vec<_>>>()?,
linear: self
.linear
.iter()
.map(|(var, coeff)| {
let new_var = var_transform(var);
(new_var, coeff.transform_var_type(var_transform))
let new_var = var_transform(var)?;
Some((new_var, coeff.try_transform_var_type(var_transform)?))
})
.collect(),
constant: self.constant.transform_var_type(var_transform),
}
.collect::<Option<BTreeMap<_, _>>>()?,
constant: self.constant.try_transform_var_type(var_transform)?,
})
}
}

View File

@@ -28,7 +28,7 @@ pub fn apply_substitutions<T: RuntimeConstant + Substitutable<V>, V: Hash + Eq +
/// Structure on top of a [`ConstraintSystem`] that stores indices
/// to more efficiently update the constraints.
#[derive(Clone, Default)]
#[derive(Clone)]
pub struct IndexedConstraintSystem<T, V> {
/// The constraint system.
constraint_system: ConstraintSystem<T, V>,
@@ -36,6 +36,15 @@ pub struct IndexedConstraintSystem<T, V> {
variable_occurrences: HashMap<V, BTreeSet<ConstraintSystemItem>>,
}
impl<T, V> Default for IndexedConstraintSystem<T, V> {
fn default() -> Self {
IndexedConstraintSystem {
constraint_system: ConstraintSystem::default(),
variable_occurrences: HashMap::new(),
}
}
}
/// Structure on top of [`IndexedConstraintSystem`] that
/// tracks changes to variables and how they may affect constraints.
///
@@ -44,12 +53,21 @@ pub struct IndexedConstraintSystem<T, V> {
/// and are put in a queue. Handling an item can cause an update to a variable,
/// which causes all constraints referencing that variable to be put back into the
/// queue.
#[derive(Clone, Default)]
#[derive(Clone)]
pub struct IndexedConstraintSystemWithQueue<T, V> {
constraint_system: IndexedConstraintSystem<T, V>,
queue: ConstraintSystemQueue,
}
impl<T, V> Default for IndexedConstraintSystemWithQueue<T, V> {
fn default() -> Self {
IndexedConstraintSystemWithQueue {
constraint_system: IndexedConstraintSystem::default(),
queue: ConstraintSystemQueue::default(),
}
}
}
/// A reference to an item in the constraint system.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
enum ConstraintSystemItem {
@@ -115,6 +133,10 @@ impl<T: RuntimeConstant, V: Clone + Eq> IndexedConstraintSystem<T, V> {
&self.constraint_system.bus_interactions
}
pub fn variables(&self) -> impl Iterator<Item = &V> {
self.variable_occurrences.keys()
}
/// Returns all expressions that appear in the constraint system, i.e. all algebraic
/// constraints and all expressions in bus interactions.
pub fn expressions(&self) -> impl Iterator<Item = &GroupedExpression<T, V>> {
@@ -192,6 +214,7 @@ fn retain<V, Item>(
})
.collect();
});
occurrences.retain(|_, occurrences| !occurrences.is_empty());
}
impl<T: RuntimeConstant, V: Clone + Ord + Hash> IndexedConstraintSystem<T, V> {
@@ -457,6 +480,23 @@ where
}));
}
pub fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
) {
let initial_len = self
.constraint_system
.constraint_system
.bus_interactions
.len();
self.constraint_system
.add_bus_interactions(bus_interactions.into_iter().enumerate().map(|(i, c)| {
self.queue
.push(ConstraintSystemItem::BusInteraction(initial_len + i));
c
}));
}
pub fn retain_algebraic_constraints(
&mut self,
mut f: impl FnMut(&GroupedExpression<T, V>) -> bool,

View File

@@ -92,7 +92,15 @@ pub trait VarTransformable<V1, V2> {
type Transformed;
/// Transforms `self` by applying the `var_transform` function to all variables.
fn transform_var_type(&self, var_transform: &mut impl FnMut(&V1) -> V2) -> Self::Transformed;
fn transform_var_type(&self, var_transform: &mut impl FnMut(&V1) -> V2) -> Self::Transformed {
self.try_transform_var_type(&mut |v| Some(var_transform(v)))
.unwrap()
}
fn try_transform_var_type(
&self,
var_transform: &mut impl FnMut(&V1) -> Option<V2>,
) -> Option<Self::Transformed>;
}
impl<T: FieldElement> RuntimeConstant for T {
@@ -138,4 +146,12 @@ impl<T: FieldElement, V1, V2> VarTransformable<V1, V2> for T {
// No variables to transform.
*self
}
fn try_transform_var_type(
&self,
_var_transform: &mut impl FnMut(&V1) -> Option<V2>,
) -> Option<Self::Transformed> {
// No variables to transform.
Some(*self)
}
}

View File

@@ -1,14 +1,16 @@
use powdr_number::{ExpressionConvertible, FieldElement};
use crate::boolean_extractor::try_extract_boolean;
use crate::constraint_system::{
BusInteractionHandler, ConstraintRef, ConstraintSystem, DefaultBusInteractionHandler,
BusInteraction, BusInteractionHandler, ConstraintRef, ConstraintSystem,
};
use crate::effect::Effect;
use crate::grouped_expression::GroupedExpression;
use crate::indexed_constraint_system::IndexedConstraintSystemWithQueue;
use crate::range_constraint::RangeConstraint;
use crate::runtime_constant::{ReferencedSymbols, RuntimeConstant, Substitutable};
use crate::utils::known_variables;
use crate::runtime_constant::{
ReferencedSymbols, RuntimeConstant, Substitutable, VarTransformable,
};
use super::grouped_expression::{Error as QseError, RangeConstraintProvider};
use std::collections::{HashMap, HashSet};
@@ -24,16 +26,16 @@ pub fn solve_system<T, V>(
bus_interaction_handler: impl BusInteractionHandler<T::FieldType>,
) -> Result<Vec<VariableAssignment<T, V>>, Error>
where
T: RuntimeConstant
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
+ VarTransformable<Variable<V>, V, Transformed = T>
+ ReferencedSymbols<Variable<V>>
+ Display
+ ReferencedSymbols<V>
+ Substitutable<V>
+ ExpressionConvertible<T::FieldType, V>,
+ ExpressionConvertible<T::FieldType, Variable<V>>
+ Substitutable<Variable<V>>,
V: Ord + Clone + Hash + Eq + Display,
{
SolverImpl::new(constraint_system)
.with_bus_interaction_handler(bus_interaction_handler)
.solve()
new_solver(constraint_system, bus_interaction_handler).solve()
}
/// Creates a new solver for the given system and bus interaction handler.
@@ -42,19 +44,23 @@ pub fn new_solver<T, V>(
bus_interaction_handler: impl BusInteractionHandler<T::FieldType>,
) -> impl Solver<T, V>
where
T: RuntimeConstant
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
+ VarTransformable<Variable<V>, V, Transformed = T>
+ ReferencedSymbols<Variable<V>>
+ Display
+ ReferencedSymbols<V>
+ Substitutable<V>
+ ExpressionConvertible<T::FieldType, V>,
+ ExpressionConvertible<T::FieldType, Variable<V>>
+ Substitutable<Variable<V>>,
V: Ord + Clone + Hash + Eq + Display,
{
SolverImpl::new(constraint_system).with_bus_interaction_handler(bus_interaction_handler)
let solver = SolverImpl::new(bus_interaction_handler);
let mut boolean_extracted_solver = BooleanExtractedSolver::new(solver);
boolean_extracted_solver.add_algebraic_constraints(constraint_system.algebraic_constraints);
boolean_extracted_solver.add_bus_interactions(constraint_system.bus_interactions);
boolean_extracted_solver
}
pub trait Solver<T: RuntimeConstant, V: Ord + Clone + Eq>:
RangeConstraintProvider<T::FieldType, V> + Sized
{
pub trait Solver<T: RuntimeConstant, V>: RangeConstraintProvider<T::FieldType, V> + Sized {
/// Solves the constraints as far as possible, returning concrete variable
/// assignments. Does not return the same assignments again if called more than once.
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, V>>, Error>;
@@ -65,19 +71,24 @@ pub trait Solver<T: RuntimeConstant, V: Ord + Clone + Eq>:
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
);
/// Removes all variables except those in `variables_to_keep`.
/// The idea is that the outside system is not interested in the variables
/// any more. This should remove all constraints that include one of the variables
/// and also remove all variables derived from those variables.
/// Adds a new bus interaction to the system.
fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
);
/// Adds a new range constraint for the variable.
fn add_range_constraint(&mut self, var: &V, constraint: RangeConstraint<T::FieldType>);
/// Permits the solver to remove all variables except those in `variables_to_keep`.
/// This should only keep the constraints that reference at least one of the variables.
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>);
/// Returns the best known range constraint for the given expression.
fn range_constraint_for_expression(
&self,
expr: &GroupedExpression<T, V>,
) -> RangeConstraint<T::FieldType> {
expr.range_constraint(self)
}
) -> RangeConstraint<T::FieldType>;
}
/// An error occurred while solving the constraint system.
@@ -96,8 +107,187 @@ pub enum Error {
/// An assignment of a variable.
pub type VariableAssignment<T, V> = (V, GroupedExpression<T, V>);
/// We introduce new variables.
/// This enum avoids clashes with the original variables.
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum Variable<V> {
/// A regular variable that also exists in the original system.
Original(V),
/// A new boolean-constrained variable that was introduced by the solver.
Boolean(usize),
}
impl<V> From<V> for Variable<V> {
/// Converts a regular variable to a `Variable`.
fn from(v: V) -> Self {
Variable::Original(v)
}
}
impl<V: Clone> Variable<V> {
pub fn try_to_original(&self) -> Option<V> {
match self {
Variable::Original(v) => Some(v.clone()),
Variable::Boolean(_) => None,
}
}
}
impl<V: Display> Display for Variable<V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Original(v) => write!(f, "{v}"),
Variable::Boolean(i) => write!(f, "boolean_{i}"),
}
}
}
struct BooleanVarDispenser {
next_boolean_id: usize,
}
impl BooleanVarDispenser {
fn new() -> Self {
BooleanVarDispenser { next_boolean_id: 0 }
}
fn next_var<V>(&mut self) -> Variable<V> {
let id = self.next_boolean_id;
self.next_boolean_id += 1;
Variable::Boolean(id)
}
}
/// An implementation of `Solver` that tries to introduce new boolean variables
/// for certain quadratic constraints to make them affine.
struct BooleanExtractedSolver<T, V, S> {
solver: S,
boolean_var_dispenser: BooleanVarDispenser,
_phantom: std::marker::PhantomData<(T, V)>,
}
impl<T, V, S> BooleanExtractedSolver<T, V, S>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>>,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>,
V: Clone + Eq,
S: Solver<T::Transformed, Variable<V>>,
{
fn new(solver: S) -> Self {
Self {
solver,
boolean_var_dispenser: BooleanVarDispenser::new(),
_phantom: std::marker::PhantomData,
}
}
}
impl<T, V, S> RangeConstraintProvider<T::FieldType, V> for BooleanExtractedSolver<T, V, S>
where
T: RuntimeConstant,
S: RangeConstraintProvider<T::FieldType, Variable<V>>,
V: Clone,
{
fn get(&self, var: &V) -> RangeConstraint<T::FieldType> {
self.solver.get(&Variable::from(var.clone()))
}
}
impl<T, V, S: Display> Display for BooleanExtractedSolver<T, V, S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Boolean extracted solver:\n{}", self.solver)
}
}
impl<T, V, S> Solver<T, V> for BooleanExtractedSolver<T, V, S>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
+ VarTransformable<Variable<V>, V, Transformed = T>
+ Display,
V: Ord + Clone + Eq + Hash + Display,
S: Solver<T::Transformed, Variable<V>>,
{
/// Solves the system and ignores all assignments that contain a boolean variable
/// (either on the LHS or the RHS).
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, V>>, Error> {
let assignments = self.solver.solve()?;
Ok(assignments
.into_iter()
.filter_map(|(v, expr)| {
let v = v.try_to_original()?;
let expr = expr.try_transform_var_type(&mut |v| v.try_to_original())?;
Some((v, expr))
})
.collect())
}
fn add_algebraic_constraints(
&mut self,
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
) {
let mut new_boolean_vars = vec![];
self.solver
.add_algebraic_constraints(constraints.into_iter().flat_map(|constr| {
let constr = constr.transform_var_type(&mut |v| v.clone().into());
let extracted = try_extract_boolean(&constr, &mut || {
let v = self.boolean_var_dispenser.next_var();
new_boolean_vars.push(v.clone());
v
});
std::iter::once(constr).chain(extracted)
}));
// We need to manually add the boolean range constraints for the new variables.
for v in new_boolean_vars {
self.solver
.add_range_constraint(&v, RangeConstraint::from_mask(1));
}
}
fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
) {
self.solver
.add_bus_interactions(bus_interactions.into_iter().map(|bus_interaction| {
bus_interaction
.fields()
.map(|expr| {
// We cannot extract booleans here because that only works
// for "constr = 0".
expr.transform_var_type(&mut |v| v.clone().into())
})
.collect()
}))
}
fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint<T::FieldType>) {
self.solver
.add_range_constraint(&variable.clone().into(), constraint);
}
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>) {
// We do not add boolean variables because we want constraints
// to be removed that only reference variables to be removed and
// boolean variables derived from them.
let variables_to_keep = variables_to_keep
.iter()
.map(|v| Variable::from(v.clone()))
.collect::<HashSet<_>>();
self.solver.retain_variables(&variables_to_keep);
}
fn range_constraint_for_expression(
&self,
expr: &GroupedExpression<T, V>,
) -> RangeConstraint<T::FieldType> {
let expr = expr.transform_var_type(&mut |v| v.clone().into());
self.solver.range_constraint_for_expression(&expr)
}
}
/// Given a list of constraints, tries to derive as many variable assignments as possible.
struct SolverImpl<T: RuntimeConstant, V: Clone + Eq, BusInterHandler> {
struct SolverImpl<T: RuntimeConstant, V, BusInterHandler> {
/// The constraint system to solve. During the solving process, any expressions will
/// be simplified as much as possible.
constraint_system: IndexedConstraintSystemWithQueue<T, V>,
@@ -111,33 +301,13 @@ struct SolverImpl<T: RuntimeConstant, V: Clone + Eq, BusInterHandler> {
assignments_to_return: Vec<VariableAssignment<T, V>>,
}
impl<T: RuntimeConstant + ReferencedSymbols<V>, V: Ord + Clone + Hash + Eq + Display>
SolverImpl<T, V, DefaultBusInteractionHandler<T::FieldType>>
{
fn new(constraint_system: ConstraintSystem<T, V>) -> Self {
assert!(
known_variables(constraint_system.expressions()).is_empty(),
"Expected all variables to be unknown."
);
impl<T: RuntimeConstant, V, B: BusInteractionHandler<T::FieldType>> SolverImpl<T, V, B> {
fn new(bus_interaction_handler: B) -> Self {
SolverImpl {
constraint_system: IndexedConstraintSystemWithQueue::from(constraint_system),
constraint_system: Default::default(),
range_constraints: Default::default(),
bus_interaction_handler: Default::default(),
assignments_to_return: Default::default(),
}
}
pub fn with_bus_interaction_handler<B: BusInteractionHandler<T::FieldType>>(
self,
bus_interaction_handler: B,
) -> SolverImpl<T, V, B> {
assert!(self.assignments_to_return.is_empty());
SolverImpl {
bus_interaction_handler,
constraint_system: self.constraint_system,
range_constraints: self.range_constraints,
assignments_to_return: self.assignments_to_return,
}
}
}
@@ -146,13 +316,20 @@ impl<T, V, BusInter> RangeConstraintProvider<T::FieldType, V> for SolverImpl<T,
where
V: Clone + Hash + Eq,
T: RuntimeConstant,
BusInter: BusInteractionHandler<T::FieldType>,
{
fn get(&self, var: &V) -> RangeConstraint<T::FieldType> {
self.range_constraints.get(var)
}
}
impl<T: RuntimeConstant + Display, V: Clone + Ord + Hash + Display, BusInter> Display
for SolverImpl<T, V, BusInter>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.constraint_system)
}
}
impl<T, V, BusInter: BusInteractionHandler<T::FieldType>> Solver<T, V>
for SolverImpl<T, V, BusInter>
where
@@ -176,11 +353,20 @@ where
.add_algebraic_constraints(constraints);
}
fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
) {
self.constraint_system
.add_bus_interactions(bus_interactions);
}
fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint<T::FieldType>) {
self.apply_range_constraint_update(variable, constraint);
}
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>) {
assert!(self.assignments_to_return.is_empty());
self.range_constraints
.range_constraints
.retain(|v, _| variables_to_keep.contains(v));
self.constraint_system.retain_algebraic_constraints(|c| {
c.referenced_variables()
.any(|v| variables_to_keep.contains(v))
@@ -191,6 +377,21 @@ where
.referenced_variables()
.any(|v| variables_to_keep.contains(v))
});
let remaining_variables = self
.constraint_system
.system()
.variables()
.collect::<HashSet<_>>();
self.range_constraints
.range_constraints
.retain(|v, _| remaining_variables.contains(v));
}
fn range_constraint_for_expression(
&self,
expr: &GroupedExpression<T, V>,
) -> RangeConstraint<T::FieldType> {
expr.range_constraint(self)
}
}

View File

@@ -169,31 +169,31 @@ impl<T: FieldElement, S1: Ord + Clone, S2: Ord + Clone> VarTransformable<S1, S2>
{
type Transformed = SymbolicExpression<T, S2>;
fn transform_var_type(
fn try_transform_var_type(
&self,
var_transform: &mut impl FnMut(&S1) -> S2,
) -> SymbolicExpression<T, S2> {
match self {
var_transform: &mut impl FnMut(&S1) -> Option<S2>,
) -> Option<SymbolicExpression<T, S2>> {
Some(match self {
SymbolicExpression::Concrete(n) => SymbolicExpression::Concrete(*n),
SymbolicExpression::Symbol(v, rc) => {
SymbolicExpression::from_symbol(var_transform(v), rc.clone())
SymbolicExpression::from_symbol(var_transform(v)?, rc.clone())
}
SymbolicExpression::BinaryOperation(lhs, op, rhs, rc) => {
SymbolicExpression::BinaryOperation(
Arc::new(lhs.transform_var_type(var_transform)),
Arc::new(lhs.try_transform_var_type(var_transform)?),
*op,
Arc::new(rhs.transform_var_type(var_transform)),
Arc::new(rhs.try_transform_var_type(var_transform)?),
rc.clone(),
)
}
SymbolicExpression::UnaryOperation(op, inner, rc) => {
SymbolicExpression::UnaryOperation(
*op,
Arc::new(inner.transform_var_type(var_transform)),
Arc::new(inner.try_transform_var_type(var_transform)?),
rc.clone(),
)
}
}
})
}
}