Remove one level of nested calls in the optimizer (#3218)

In order to make the optimizer clearer, this PR tries to put all
optimization passes one after another in `optimize_constraints`.

Also removes `JournalingConstraintSystem`.

TODO:
- [x] Figure out what to do about the `JournalingConstraintSystem`
This commit is contained in:
Thibaut Schaeffer
2025-09-01 15:41:38 +02:00
committed by GitHub
parent f32fd0c81a
commit 396737e749
7 changed files with 125 additions and 119 deletions

View File

@@ -20,9 +20,7 @@ pub fn apply_substitutions<T: RuntimeConstant + Substitutable<V>, V: Hash + Eq +
substitutions: impl IntoIterator<Item = (V, GroupedExpression<T, V>)>,
) -> ConstraintSystem<T, V> {
let mut indexed_constraint_system = IndexedConstraintSystem::from(constraint_system);
for (variable, substitution) in substitutions {
indexed_constraint_system.substitute_by_unknown(&variable, &substitution);
}
indexed_constraint_system.apply_substitutions(substitutions);
indexed_constraint_system.into()
}
@@ -350,6 +348,17 @@ impl<T: RuntimeConstant + Substitutable<V>, V: Clone + Hash + Ord + Eq>
.extend(items.iter().cloned());
}
}
/// Applies multiple substitutions to the constraint system in an efficient manner.
pub fn apply_substitutions(
&mut self,
substitutions: impl IntoIterator<Item = (V, GroupedExpression<T, V>)>,
) {
// We do not track substitutions yet, but we could.
for (variable, substitution) in substitutions {
self.substitute_by_unknown(&variable, &substitution);
}
}
}
/// Returns a hash map mapping all unknown variables in the constraint system

View File

@@ -1,7 +1,6 @@
use crate::constraint_system::{AlgebraicConstraint, ConstraintRef};
use crate::grouped_expression::GroupedExpression;
use crate::indexed_constraint_system::IndexedConstraintSystem;
use crate::journaling_constraint_system::JournalingConstraintSystem;
use crate::runtime_constant::{RuntimeConstant, Substitutable};
use itertools::Itertools;
@@ -23,22 +22,19 @@ pub fn replace_constrained_witness_columns<
T: RuntimeConstant + ExpressionConvertible<T::FieldType, V> + Substitutable<V> + Display,
V: Ord + Clone + Hash + Eq + Display,
>(
mut constraint_system: JournalingConstraintSystem<T, V>,
mut constraint_system: IndexedConstraintSystem<T, V>,
should_inline: impl Fn(&V, &GroupedExpression<T, V>, &IndexedConstraintSystem<T, V>) -> bool,
) -> JournalingConstraintSystem<T, V> {
) -> IndexedConstraintSystem<T, V> {
let mut to_remove_idx = HashSet::new();
let mut inlined_vars = HashSet::new();
let constraint_count = constraint_system
.indexed_system()
.algebraic_constraints()
.len();
let constraint_count = constraint_system.algebraic_constraints().len();
loop {
let inlined_vars_count = inlined_vars.len();
for curr_idx in (0..constraint_count).rev() {
let constraint = &constraint_system.indexed_system().algebraic_constraints()[curr_idx];
let constraint = &constraint_system.algebraic_constraints()[curr_idx];
for (var, expr) in find_inlinable_variables(constraint) {
if should_inline(&var, &expr, constraint_system.indexed_system()) {
if should_inline(&var, &expr, &constraint_system) {
log::trace!("Substituting {var} = {expr}");
log::trace!(" (from identity {constraint})");
@@ -185,7 +181,7 @@ mod test {
let constraint_system =
replace_constrained_witness_columns(constraint_system, bounds(3, 3));
assert_eq!(constraint_system.algebraic_constraints().count(), 2);
assert_eq!(constraint_system.algebraic_constraints().len(), 2);
}
#[test]
@@ -216,10 +212,10 @@ mod test {
// = 0 + 0 + d
// => result = d = -b + 1
// => b = -result + 1
assert_eq!(constraint_system.algebraic_constraints().count(), 0);
assert_eq!(constraint_system.algebraic_constraints().len(), 0);
let bus_interactions = constraint_system.bus_interactions().collect_vec();
let [BusInteraction { payload, .. }] = &bus_interactions[..] else {
let bus_interactions = constraint_system.bus_interactions();
let [BusInteraction { payload, .. }] = bus_interactions else {
panic!();
};
let [result, b] = payload.as_slice() else {
@@ -261,7 +257,7 @@ mod test {
let constraint_system =
replace_constrained_witness_columns(constraint_system, bounds(3, 3));
let constraints = constraint_system.algebraic_constraints().collect_vec();
let constraints = constraint_system.algebraic_constraints();
assert_eq!(constraints.len(), 0);
}
@@ -294,7 +290,7 @@ mod test {
let constraint_system =
replace_constrained_witness_columns(constraint_system, bounds(3, 3));
let constraints = constraint_system.algebraic_constraints().collect_vec();
let constraints = constraint_system.algebraic_constraints();
assert_eq!(constraints.len(), 0);
}
@@ -321,8 +317,8 @@ mod test {
// 1) y = x + 3
// 2) z = y + 2 ⇒ z = (x + 3) + 2 = x + 5
// 3) result = z + 1 ⇒ result = (x + 5) + 1 = x + 6
let bus_interactions = constraint_system.bus_interactions().collect_vec();
let [BusInteraction { payload, .. }] = &bus_interactions[..] else {
let bus_interactions = constraint_system.bus_interactions();
let [BusInteraction { payload, .. }] = bus_interactions else {
panic!();
};
let [result, x] = payload.as_slice() else {
@@ -364,12 +360,12 @@ mod test {
let constraint_system =
replace_constrained_witness_columns(constraint_system, bounds(3, 3));
let constraints = constraint_system.algebraic_constraints().collect_vec();
let [identity] = &constraints[..] else {
let constraints = constraint_system.algebraic_constraints();
let [identity] = constraints else {
panic!();
};
let bus_interactions = constraint_system.bus_interactions().collect_vec();
let [BusInteraction { payload, .. }] = &bus_interactions[..] else {
let bus_interactions = constraint_system.bus_interactions();
let [BusInteraction { payload, .. }] = bus_interactions else {
panic!();
};
let [a, b, c, d, e, f, result] = payload.as_slice() else {
@@ -440,7 +436,7 @@ mod test {
replace_constrained_witness_columns(suboptimal_system, bounds(5, 5));
// Assert the difference in optimization results
assert_eq!(optimal_system.algebraic_constraints().count(), 3);
assert_eq!(suboptimal_system.algebraic_constraints().count(), 4);
assert_eq!(optimal_system.algebraic_constraints().len(), 3);
assert_eq!(suboptimal_system.algebraic_constraints().len(), 4);
}
}

View File

@@ -6,7 +6,6 @@ pub mod effect;
pub mod grouped_expression;
pub mod indexed_constraint_system;
pub mod inliner;
pub mod journaling_constraint_system;
pub mod range_constraint;
pub mod runtime_constant;
pub mod solver;