Use solver directly for memory (#3111)

Co-authored-by: Georg Wiese <georgwiese@gmail.com>
This commit is contained in:
chriseth
2025-08-05 21:04:18 +02:00
committed by GitHub
parent 24ddce250d
commit 356ceba690
6 changed files with 166 additions and 303 deletions

View File

@@ -1,16 +1,11 @@
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Display;
use std::hash::Hash;
use std::marker::PhantomData;
use itertools::Itertools;
use powdr_constraint_solver::boolean_extractor::{self, RangeConstraintsForBooleans};
use powdr_constraint_solver::constraint_system::{BusInteraction, ConstraintRef, ConstraintSystem};
use powdr_constraint_solver::grouped_expression::{GroupedExpression, RangeConstraintProvider};
use powdr_constraint_solver::indexed_constraint_system::IndexedConstraintSystem;
use powdr_constraint_solver::runtime_constant::{RuntimeConstant, VarTransformable};
use powdr_constraint_solver::constraint_system::{BusInteraction, ConstraintSystem};
use powdr_constraint_solver::grouped_expression::GroupedExpression;
use powdr_constraint_solver::solver::Solver;
use powdr_constraint_solver::utils::possible_concrete_values;
use powdr_number::FieldElement;
/// Optimizes bus sends that correspond to general-purpose memory read and write operations.
@@ -24,11 +19,10 @@ pub fn optimize_memory<
mut system: ConstraintSystem<T, V>,
solver: &mut impl Solver<T, V>,
memory_bus_id: u64,
range_constraints: impl RangeConstraintProvider<T, V> + Clone,
) -> ConstraintSystem<T, V> {
// TODO use the solver here.
let (to_remove, new_constraints) =
redundant_memory_interactions_indices::<T, V, M>(&system, memory_bus_id, range_constraints);
redundant_memory_interactions_indices::<T, V, M>(&system, solver, memory_bus_id);
let to_remove = to_remove.into_iter().collect::<HashSet<_>>();
system.bus_interactions = system
.bus_interactions
@@ -139,10 +133,9 @@ fn redundant_memory_interactions_indices<
M: MemoryBusInteraction<T, V>,
>(
system: &ConstraintSystem<T, V>,
solver: &mut impl Solver<T, V>,
memory_bus_id: u64,
range_constraints: impl RangeConstraintProvider<T, V> + Clone,
) -> (Vec<usize>, Vec<GroupedExpression<T, V>>) {
let address_comparator = MemoryAddressComparator::<T, V, M>::new(system, memory_bus_id);
let mut new_constraints: Vec<GroupedExpression<T, V>> = Vec::new();
// Track memory contents by memory type while we go through bus interactions.
@@ -186,11 +179,11 @@ fn redundant_memory_interactions_indices<
// that this send operation does not interfere with it, i.e.
// if we can prove that the two addresses differ by at least a word size.
memory_contents.retain(|other_addr, _| {
address_comparator.are_addrs_known_to_be_different(
&addr,
other_addr,
range_constraints.clone(),
)
addr.0
.iter()
.zip_eq(other_addr.0.iter())
// Two addresses are different if they differ in at least one component.
.any(|(a, b)| solver.are_expressions_known_to_be_different(a, b))
});
memory_contents.insert(addr.clone(), (index, mem_int.data().to_vec()));
}
@@ -205,186 +198,3 @@ fn redundant_memory_interactions_indices<
(to_remove, new_constraints)
}
type BooleanExtractedExpression<T, V> = GroupedExpression<T, boolean_extractor::Variable<V>>;
/// An address, represented as a list of boolean-extracted expressions.
type BooleanExtractedAddress<T, V> = Vec<BooleanExtractedExpression<T, V>>;
struct MemoryAddressComparator<T, V, M> {
/// For each address `a` contains a list of expressions `v` such that
/// `a = v` is true in the constraint system.
memory_addresses: HashMap<BooleanExtractedAddress<T, V>, Vec<BooleanExtractedAddress<T, V>>>,
_marker: PhantomData<M>,
}
impl<T: FieldElement, V: Ord + Clone + Hash + Display, M: MemoryBusInteraction<T, V>>
MemoryAddressComparator<T, V, M>
{
fn new(system: &ConstraintSystem<T, V>, memory_bus_id: u64) -> Self {
let addresses = system
.bus_interactions
.iter()
.flat_map(|bus| {
M::try_from_bus_interaction(bus, memory_bus_id)
.ok()
.flatten()
})
.map(|bus| bus.addr().into());
let constraints =
boolean_extractor::to_boolean_extracted_system(&system.algebraic_constraints);
let constraint_system: IndexedConstraintSystem<_, _> = ConstraintSystem {
algebraic_constraints: constraints,
bus_interactions: vec![],
}
.into();
let memory_addresses = addresses
.map(|addr| {
// Represent an address as a list of expressions.
let addr = Self::boolean_extracted_address(addr);
let equivalent_expressions = addr
.iter()
// Note that `find_equivalent_expressions` returns the input expression for constants,
// which is a common case.
.map(|addr| find_equivalent_expressions(addr, &constraint_system))
.multi_cartesian_product()
.collect::<Vec<_>>();
(addr, equivalent_expressions)
})
.collect();
Self {
memory_addresses,
_marker: PhantomData,
}
}
fn boolean_extracted_address(address: Address<T, V>) -> BooleanExtractedAddress<T, V> {
address
.0
.into_iter()
.map(|v| v.transform_var_type(&mut |v| v.into()))
.collect()
}
/// Returns true if we can prove that for two addresses `a` and `b`,
/// `a - b` cannot be 0.
pub fn are_addrs_known_to_be_different(
&self,
a: &Address<T, V>,
b: &Address<T, V>,
rc: impl RangeConstraintProvider<T, V> + Clone,
) -> bool {
let a = Self::boolean_extracted_address(a.clone());
let b = Self::boolean_extracted_address(b.clone());
let a_exprs = &self.memory_addresses[&a];
let b_exprs = &self.memory_addresses[&b];
let range_constraints = RangeConstraintsForBooleans::from(rc.clone());
a_exprs
.iter()
.cartesian_product(b_exprs)
.any(|(a_expr, b_expr)| {
// Compare all pairs of address fields. We know the addresses are different
// if at least one pair of fields is known to be different.
a_expr.iter().zip_eq(b_expr).any(|(a_expr, b_expr)| {
is_known_to_be_nonzero(&(a_expr - b_expr), &range_constraints)
})
})
}
}
/// Tries to find equivalent expressions for the given expression
/// according to the given constraint system.
/// Returns at least one equivalent expression (in the worst case, the expression itself).
fn find_equivalent_expressions<T: FieldElement, V: Clone + Ord + Hash + Eq + Display>(
expression: &GroupedExpression<T, V>,
constraints: &IndexedConstraintSystem<T, V>,
) -> Vec<GroupedExpression<T, V>> {
if expression.is_quadratic() {
// This case is too complicated.
return vec![expression.clone()];
}
// Go through the constraints related to this address
// and try to solve for the expression
let mut exprs = constraints
.constraints_referencing_variables(expression.referenced_unknown_variables().cloned())
.filter_map(|constr| match constr {
ConstraintRef::AlgebraicConstraint(constr) => Some(constr),
ConstraintRef::BusInteraction(_) => None,
})
.flat_map(|constr| constr.try_solve_for_expr(expression))
.collect_vec();
if exprs.is_empty() {
// If we cannot solve for the expression, we just take the expression unmodified.
exprs.push(expression.clone());
}
exprs
}
/// Returns true if we can prove that `expr` cannot be 0.
fn is_known_to_be_nonzero<T: FieldElement, V: Clone + Ord + Hash + Eq + Display>(
expr: &GroupedExpression<T, V>,
range_constraints: &impl RangeConstraintProvider<T, V>,
) -> bool {
possible_concrete_values(expr, range_constraints, 20)
.is_some_and(|mut values| values.all(|value| value.is_known_nonzero()))
}
#[cfg(test)]
mod tests {
use super::*;
use powdr_constraint_solver::{
grouped_expression::NoRangeConstraints, range_constraint::RangeConstraint,
};
use powdr_number::GoldilocksField;
type Var = &'static str;
type Qse = GroupedExpression<GoldilocksField, Var>;
fn var(name: Var) -> Qse {
Qse::from_unknown_variable(name)
}
fn constant(value: u64) -> Qse {
Qse::from_number(GoldilocksField::from(value))
}
#[test]
fn is_known_to_by_nonzero() {
assert!(!is_known_to_be_nonzero(&constant(0), &NoRangeConstraints));
assert!(is_known_to_be_nonzero(&constant(1), &NoRangeConstraints));
assert!(is_known_to_be_nonzero(&constant(7), &NoRangeConstraints));
assert!(is_known_to_be_nonzero(&-constant(1), &NoRangeConstraints));
assert!(!is_known_to_be_nonzero(
&(constant(42) - constant(2) * var("a")),
&NoRangeConstraints
));
assert!(!is_known_to_be_nonzero(
&(var("a") - var("b")),
&NoRangeConstraints
));
struct AllVarsThreeOrFour;
impl RangeConstraintProvider<GoldilocksField, &'static str> for AllVarsThreeOrFour {
fn get(&self, _var: &&'static str) -> RangeConstraint<GoldilocksField> {
RangeConstraint::from_range(GoldilocksField::from(3), GoldilocksField::from(4))
}
}
assert!(is_known_to_be_nonzero(&var("a"), &AllVarsThreeOrFour));
assert!(is_known_to_be_nonzero(
// Can't be zero for all assignments of a and b.
&(var("a") - constant(2) * var("b")),
&AllVarsThreeOrFour
));
assert!(!is_known_to_be_nonzero(
// Can be zero for a = 4, b = 3.
&(constant(3) * var("a") - constant(4) * var("b")),
&AllVarsThreeOrFour
));
}
}

View File

@@ -13,7 +13,7 @@ use powdr_constraint_solver::runtime_constant::RuntimeConstant;
use powdr_constraint_solver::solver::{new_solver, Solver};
use powdr_constraint_solver::{
constraint_system::{BusInteraction, ConstraintSystem},
grouped_expression::{GroupedExpression, NoRangeConstraints},
grouped_expression::GroupedExpression,
journaling_constraint_system::JournalingConstraintSystem,
runtime_constant::VarTransformable,
};
@@ -154,12 +154,8 @@ fn optimization_loop_iteration<
)?;
let constraint_system = constraint_system.system().clone();
let constraint_system = if let Some(memory_bus_id) = bus_map.get_bus_id(&BusType::Memory) {
let constraint_system = optimize_memory::<_, _, M>(
constraint_system,
solver,
memory_bus_id,
NoRangeConstraints,
);
let constraint_system =
optimize_memory::<_, _, M>(constraint_system, solver, memory_bus_id);
assert!(check_register_operation_consistency::<_, _, M>(
&constraint_system,
memory_bus_id

View File

@@ -1,12 +1,6 @@
use std::{fmt::Display, hash::Hash};
use std::hash::Hash;
use crate::{
grouped_expression::{GroupedExpression, RangeConstraintProvider},
range_constraint::RangeConstraint,
runtime_constant::{RuntimeConstant, VarTransformable},
};
use itertools::Itertools;
use powdr_number::FieldElement;
use crate::{grouped_expression::GroupedExpression, runtime_constant::RuntimeConstant};
/// Tries to simplify a quadratic constraint by transforming it into an affine
/// constraint that makes use of a new boolean variable.
@@ -75,93 +69,6 @@ pub fn try_extract_boolean<T: RuntimeConstant, V: Ord + Clone + Hash + Eq>(
}
}
/// Tries to simplify a sequence of constraints by transforming them into affine
/// constraints that make use of a new variable that is assumed to be boolean constrained.
/// NOTE: The boolean constraint is not part of the output.
///
/// For example `(a + b) * (a + b + 10) = 0` can be transformed into
/// `a + b + z * 10 = 0`, where `z` is a new boolean variable.
///
/// The constraints in the output use a new variable type that can be converted from
/// the original variable type.
pub fn to_boolean_extracted_system<
'a,
T: RuntimeConstant + VarTransformable<V, Variable<V>> + 'a,
V: Ord + Clone + Hash + Eq + 'a,
>(
constraints: impl IntoIterator<Item = &'a GroupedExpression<T, V>>,
) -> Vec<GroupedExpression<T::Transformed, Variable<V>>>
where
T::Transformed: RuntimeConstant,
{
let mut counter = 0..;
let mut var_dispenser = || Variable::Boolean(counter.next().unwrap());
constraints
.into_iter()
.map(|constr| {
let constr = constr.transform_var_type(&mut |v| v.into());
try_extract_boolean(&constr, &mut var_dispenser).unwrap_or(constr)
})
.collect_vec()
}
/// Range constraint provider that works for `Variable` and delegates range constraint requests
/// for original variables to a provided range constraint provider.
#[derive(Default)]
pub struct RangeConstraintsForBooleans<T: FieldElement, V, R: RangeConstraintProvider<T, V>> {
range_constraints: R,
_phantom: std::marker::PhantomData<(T, V)>,
}
impl<T: FieldElement, V, R: RangeConstraintProvider<T, V>> RangeConstraintProvider<T, Variable<V>>
for RangeConstraintsForBooleans<T, V, R>
{
fn get(&self, variable: &Variable<V>) -> RangeConstraint<T> {
match variable {
Variable::Boolean(_) => RangeConstraint::from_mask(1),
Variable::Original(v) => self.range_constraints.get(v),
}
}
}
impl<T: FieldElement, V, R: RangeConstraintProvider<T, V>> From<R>
for RangeConstraintsForBooleans<T, V, R>
{
fn from(range_constraints: R) -> Self {
RangeConstraintsForBooleans {
range_constraints,
_phantom: std::marker::PhantomData,
}
}
}
/// We introduce new variables (that are always boolean-constrained).
/// This enum avoids clashes with the original variables.
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum Variable<V> {
/// A regular variable that also exists in the original system.
Original(V),
/// A new boolean-constrained variable that was introduced by the solver.
Boolean(usize),
}
impl<V: Clone> From<&V> for Variable<V> {
/// Converts a regular variable to a `Variable`.
fn from(v: &V) -> Self {
Variable::Original(v.clone())
}
}
impl<V: Display> Display for Variable<V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Original(v) => write!(f, "{v}"),
Variable::Boolean(i) => write!(f, "boolean_{i}"),
}
}
}
#[cfg(test)]
mod tests {
use crate::test_utils::{constant, var};

View File

@@ -10,6 +10,7 @@ use crate::solver::base::BaseSolver;
use crate::solver::boolean_extracted::{BooleanExtractedSolver, Variable};
use super::grouped_expression::{Error as QseError, RangeConstraintProvider};
use std::collections::HashSet;
use std::fmt::{Debug, Display};
use std::hash::Hash;
@@ -31,7 +32,8 @@ where
+ ReferencedSymbols<Variable<V>>
+ Display
+ ExpressionConvertible<T::FieldType, Variable<V>>
+ Substitutable<Variable<V>>,
+ Substitutable<Variable<V>>
+ Hash,
V: Ord + Clone + Hash + Eq + Display,
{
new_solver(constraint_system, bus_interaction_handler).solve()
@@ -49,7 +51,8 @@ where
+ ReferencedSymbols<Variable<V>>
+ Display
+ ExpressionConvertible<T::FieldType, Variable<V>>
+ Substitutable<Variable<V>>,
+ Substitutable<Variable<V>>
+ Hash,
V: Ord + Clone + Hash + Eq + Display,
{
let solver = BaseSolver::new(bus_interaction_handler);
@@ -88,6 +91,16 @@ pub trait Solver<T: RuntimeConstant, V>: RangeConstraintProvider<T::FieldType, V
&self,
expr: &GroupedExpression<T, V>,
) -> RangeConstraint<T::FieldType>;
/// Returns `true` if `a` and `b` are different for all satisfying assignments.
/// In other words, `a - b` does not allow the value zero.
/// If this function returns `false`, it does not mean that `a` and `b` are equal,
/// i.e. a function always returning `false` here satisfies the trait.
fn are_expressions_known_to_be_different(
&mut self,
a: &GroupedExpression<T, V>,
b: &GroupedExpression<T, V>,
) -> bool;
}
/// An error occurred while solving the constraint system.

View File

@@ -1,3 +1,4 @@
use itertools::Itertools;
use powdr_number::{ExpressionConvertible, FieldElement};
use crate::constraint_system::{BusInteraction, BusInteractionHandler, ConstraintRef};
@@ -7,6 +8,7 @@ use crate::indexed_constraint_system::IndexedConstraintSystemWithQueue;
use crate::range_constraint::RangeConstraint;
use crate::runtime_constant::{ReferencedSymbols, RuntimeConstant, Substitutable};
use crate::solver::{exhaustive_search, quadratic_equivalences, Error, Solver, VariableAssignment};
use crate::utils::possible_concrete_values;
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
@@ -25,6 +27,8 @@ pub struct BaseSolver<T: RuntimeConstant, V, BusInterHandler> {
/// that do not occur in the constraints any more.
/// This is cleared with every call to `solve()`.
assignments_to_return: Vec<VariableAssignment<T, V>>,
/// A cache of expressions that are equivalent to a given expression.
equivalent_expressions_cache: HashMap<GroupedExpression<T, V>, Vec<GroupedExpression<T, V>>>,
}
impl<T: RuntimeConstant, V, B: BusInteractionHandler<T::FieldType>> BaseSolver<T, V, B> {
@@ -33,6 +37,7 @@ impl<T: RuntimeConstant, V, B: BusInteractionHandler<T::FieldType>> BaseSolver<T
constraint_system: Default::default(),
range_constraints: Default::default(),
assignments_to_return: Default::default(),
equivalent_expressions_cache: Default::default(),
bus_interaction_handler,
}
}
@@ -63,10 +68,12 @@ where
T: RuntimeConstant
+ ReferencedSymbols<V>
+ Display
+ Hash
+ ExpressionConvertible<T::FieldType, V>
+ Substitutable<V>,
{
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, V>>, Error> {
self.equivalent_expressions_cache.clear();
self.loop_until_no_progress()?;
Ok(std::mem::take(&mut self.assignments_to_return))
}
@@ -75,6 +82,7 @@ where
&mut self,
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
) {
self.equivalent_expressions_cache.clear();
self.constraint_system
.add_algebraic_constraints(constraints);
}
@@ -83,15 +91,18 @@ where
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
) {
self.equivalent_expressions_cache.clear();
self.constraint_system
.add_bus_interactions(bus_interactions);
}
fn add_range_constraint(&mut self, variable: &V, constraint: RangeConstraint<T::FieldType>) {
self.equivalent_expressions_cache.clear();
self.apply_range_constraint_update(variable, constraint);
}
fn retain_variables(&mut self, variables_to_keep: &HashSet<V>) {
self.equivalent_expressions_cache.clear();
assert!(self.assignments_to_return.is_empty());
self.constraint_system.retain_algebraic_constraints(|c| {
c.referenced_variables()
@@ -119,6 +130,25 @@ where
) -> RangeConstraint<T::FieldType> {
expr.range_constraint(self)
}
fn are_expressions_known_to_be_different(
&mut self,
a: &GroupedExpression<T, V>,
b: &GroupedExpression<T, V>,
) -> bool {
if let (Some(a), Some(b)) = (a.try_to_known(), b.try_to_known()) {
return (a.clone() - b.clone()).is_known_nonzero();
}
let equivalent_to_a = self.equivalent_expressions(a);
let equivalent_to_b = self.equivalent_expressions(b);
equivalent_to_a
.iter()
.cartesian_product(&equivalent_to_b)
.any(|(a_eq, b_eq)| {
possible_concrete_values(&(a_eq - b_eq), self, 20)
.is_some_and(|mut values| values.all(|value| value.is_known_nonzero()))
})
}
}
impl<T, V, BusInter: BusInteractionHandler<T::FieldType>> BaseSolver<T, V, BusInter>
@@ -127,6 +157,7 @@ where
T: RuntimeConstant
+ ReferencedSymbols<V>
+ Display
+ Hash
+ ExpressionConvertible<T::FieldType, V>
+ Substitutable<V>,
{
@@ -203,6 +234,41 @@ where
Ok(progress)
}
/// Returns a vector of expressions that are equivalent to `expression`.
/// The vector is always non-empty, it returns at least `expression` itself.
fn equivalent_expressions(
&mut self,
expression: &GroupedExpression<T, V>,
) -> Vec<GroupedExpression<T, V>> {
if expression.is_quadratic() {
// This case is too complicated.
return vec![expression.clone()];
}
if let Some(equiv) = self.equivalent_expressions_cache.get(expression) {
return equiv.clone();
}
// Go through the constraints related to this expression
// and try to solve for the expression
let mut exprs = self
.constraint_system
.system()
.constraints_referencing_variables(expression.referenced_unknown_variables().cloned())
.filter_map(|constr| match constr {
ConstraintRef::AlgebraicConstraint(constr) => Some(constr),
ConstraintRef::BusInteraction(_) => None,
})
.flat_map(|constr| constr.try_solve_for_expr(expression))
.collect_vec();
if exprs.is_empty() {
// If we cannot solve for the expression, we just take the expression unmodified.
exprs.push(expression.clone());
}
self.equivalent_expressions_cache
.insert(expression.clone(), exprs.clone());
exprs
}
fn apply_effect(&mut self, effect: Effect<T, V>) -> bool {
match effect {
Effect::Assignment(v, expr) => {
@@ -286,3 +352,64 @@ impl<T: FieldElement, V: Clone + Hash + Eq> RangeConstraints<T, V> {
}
}
}
#[cfg(test)]
mod tests {
use crate::constraint_system::DefaultBusInteractionHandler;
use super::*;
use powdr_number::GoldilocksField;
type Var = &'static str;
type Qse = GroupedExpression<GoldilocksField, Var>;
fn var(name: Var) -> Qse {
Qse::from_unknown_variable(name)
}
fn constant(value: u64) -> Qse {
Qse::from_number(GoldilocksField::from(value))
}
#[test]
fn is_known_to_by_nonzero() {
let mut solver =
BaseSolver::<GoldilocksField, Var, _>::new(DefaultBusInteractionHandler::default());
assert!(!solver.are_expressions_known_to_be_different(&constant(0), &constant(0)));
assert!(solver.are_expressions_known_to_be_different(&constant(1), &constant(0)));
assert!(solver.are_expressions_known_to_be_different(&constant(7), &constant(0)));
assert!(solver.are_expressions_known_to_be_different(&-constant(1), &constant(0)));
assert!(
!(solver.are_expressions_known_to_be_different(
&(constant(42) - constant(2) * var("a")),
&constant(0)
))
);
assert!(
!(solver.are_expressions_known_to_be_different(&(var("a") - var("b")), &constant(0)))
);
solver.add_range_constraint(
&"a",
RangeConstraint::from_range(GoldilocksField::from(3), GoldilocksField::from(4)),
);
solver.add_range_constraint(
&"b",
RangeConstraint::from_range(GoldilocksField::from(3), GoldilocksField::from(4)),
);
assert!(solver.are_expressions_known_to_be_different(&(var("a")), &constant(0)));
assert!(solver.are_expressions_known_to_be_different(
// If we try all possible assignments of a and b, this expression
// can never be zero.
&(var("a") - constant(2) * var("b")),
&constant(0)
));
assert!(!solver.are_expressions_known_to_be_different(
// Can be zero for a = 4, b = 3.
&(constant(3) * var("a") - constant(4) * var("b")),
&constant(0)
));
}
}

View File

@@ -186,4 +186,14 @@ where
let expr = expr.transform_var_type(&mut |v| v.clone().into());
self.solver.range_constraint_for_expression(&expr)
}
fn are_expressions_known_to_be_different(
&mut self,
a: &GroupedExpression<T, V>,
b: &GroupedExpression<T, V>,
) -> bool {
let a = a.transform_var_type(&mut |v| v.clone().into());
let b = b.transform_var_type(&mut |v| v.clone().into());
self.solver.are_expressions_known_to_be_different(&a, &b)
}
}