diff --git a/constraint-solver/src/indexed_constraint_system.rs b/constraint-solver/src/indexed_constraint_system.rs index 6103cbea1..866af02cf 100644 --- a/constraint-solver/src/indexed_constraint_system.rs +++ b/constraint-solver/src/indexed_constraint_system.rs @@ -5,33 +5,39 @@ use std::{ }; use itertools::Itertools; -use powdr_number::FieldElement; +use powdr_number::ExpressionConvertible; use crate::{ - constraint_system::{BusInteraction, BusInteractionHandler, ConstraintRef, ConstraintSystem}, - effect::Effect, - grouped_expression::{QuadraticSymbolicExpression, RangeConstraintProvider}, + constraint_system::{ + BusInteraction, BusInteractionHandler, ConstraintRef, ConstraintSystemGeneric, + }, + effect::EffectImpl, + grouped_expression::{GroupedExpression, RangeConstraintProvider}, + runtime_constant::{ReferencedSymbols, RuntimeConstant, Substitutable}, symbolic_expression::SymbolicExpression, }; /// Applies multiple substitutions to a ConstraintSystem in an efficient manner. -pub fn apply_substitutions( - constraint_system: ConstraintSystem, - substitutions: impl IntoIterator)>, -) -> ConstraintSystem { - let mut indexed_constraint_system = IndexedConstraintSystem::from(constraint_system); +pub fn apply_substitutions, V: Hash + Eq + Clone + Ord>( + constraint_system: ConstraintSystemGeneric, + substitutions: impl IntoIterator)>, +) -> ConstraintSystemGeneric { + let mut indexed_constraint_system = IndexedConstraintSystemGeneric::from(constraint_system); for (variable, substitution) in substitutions { indexed_constraint_system.substitute_by_unknown(&variable, &substitution); } indexed_constraint_system.into() } +pub type IndexedConstraintSystem = + IndexedConstraintSystemGeneric, V>; + /// Structure on top of a [`ConstraintSystem`] that stores indices /// to more efficiently update the constraints. #[derive(Clone, Default)] -pub struct IndexedConstraintSystem { +pub struct IndexedConstraintSystemGeneric { /// The constraint system. - constraint_system: ConstraintSystem, + constraint_system: ConstraintSystemGeneric, /// Stores where each unknown variable appears. variable_occurrences: HashMap>, } @@ -42,49 +48,49 @@ enum ConstraintSystemItem { BusInteraction(usize), } -impl From> - for IndexedConstraintSystem +impl From> + for IndexedConstraintSystemGeneric { - fn from(constraint_system: ConstraintSystem) -> Self { + fn from(constraint_system: ConstraintSystemGeneric) -> Self { let variable_occurrences = variable_occurrences(&constraint_system); - IndexedConstraintSystem { + IndexedConstraintSystemGeneric { constraint_system, variable_occurrences, } } } -impl From> - for ConstraintSystem +impl From> + for ConstraintSystemGeneric { - fn from(indexed_constraint_system: IndexedConstraintSystem) -> Self { + fn from(indexed_constraint_system: IndexedConstraintSystemGeneric) -> Self { indexed_constraint_system.constraint_system } } -impl IndexedConstraintSystem { - pub fn system(&self) -> &ConstraintSystem { +impl IndexedConstraintSystemGeneric { + pub fn system(&self) -> &ConstraintSystemGeneric { &self.constraint_system } - pub fn algebraic_constraints(&self) -> &[QuadraticSymbolicExpression] { + pub fn algebraic_constraints(&self) -> &[GroupedExpression] { &self.constraint_system.algebraic_constraints } - pub fn bus_interactions(&self) -> &[BusInteraction>] { + pub fn bus_interactions(&self) -> &[BusInteraction>] { &self.constraint_system.bus_interactions } /// Returns all expressions that appear in the constraint system, i.e. all algebraic /// constraints and all expressions in bus interactions. - pub fn expressions(&self) -> impl Iterator> { + pub fn expressions(&self) -> impl Iterator> { self.constraint_system.expressions() } /// Removes all constraints that do not fulfill the predicate. pub fn retain_algebraic_constraints( &mut self, - mut f: impl FnMut(&QuadraticSymbolicExpression) -> bool, + mut f: impl FnMut(&GroupedExpression) -> bool, ) { retain( &mut self.constraint_system.algebraic_constraints, @@ -97,7 +103,7 @@ impl IndexedConstraintSystem { /// Removes all bus interactions that do not fulfill the predicate. pub fn retain_bus_interactions( &mut self, - mut f: impl FnMut(&BusInteraction>) -> bool, + mut f: impl FnMut(&BusInteraction>) -> bool, ) { retain( &mut self.constraint_system.bus_interactions, @@ -154,13 +160,13 @@ fn retain( }); } -impl IndexedConstraintSystem { +impl IndexedConstraintSystemGeneric { /// Adds new algebraic constraints to the system. pub fn add_algebraic_constraints( &mut self, - constraints: impl IntoIterator>, + constraints: impl IntoIterator>, ) { - self.extend(ConstraintSystem { + self.extend(ConstraintSystemGeneric { algebraic_constraints: constraints.into_iter().collect(), bus_interactions: Vec::new(), }); @@ -169,16 +175,16 @@ impl IndexedConstraintSystem { /// Adds new bus interactions to the system. pub fn add_bus_interactions( &mut self, - bus_interactions: impl IntoIterator>>, + bus_interactions: impl IntoIterator>>, ) { - self.extend(ConstraintSystem { + self.extend(ConstraintSystemGeneric { algebraic_constraints: Vec::new(), bus_interactions: bus_interactions.into_iter().collect(), }); } /// Extends the constraint system by the constraints of another system. - pub fn extend(&mut self, system: ConstraintSystem) { + pub fn extend(&mut self, system: ConstraintSystemGeneric) { let algebraic_constraint_count = self.constraint_system.algebraic_constraints.len(); let bus_interactions_count = self.constraint_system.bus_interactions.len(); // Compute the occurrences of the variables in the new constraints, @@ -203,12 +209,12 @@ impl IndexedConstraintSystem { } } -impl IndexedConstraintSystem { +impl IndexedConstraintSystemGeneric { /// Returns a list of all constraints that contain at least one of the given variables. pub fn constraints_referencing_variables<'a>( &'a self, variables: impl Iterator + 'a, - ) -> impl Iterator, V>> + 'a { + ) -> impl Iterator> + 'a { variables .filter_map(|v| self.variable_occurrences.get(&v)) .flatten() @@ -222,9 +228,13 @@ impl IndexedConstraintSystem } }) } +} +impl, V: Clone + Hash + Ord + Eq> + IndexedConstraintSystemGeneric +{ /// Substitutes a variable with a symbolic expression in the whole system - pub fn substitute_by_known(&mut self, variable: &V, substitution: &SymbolicExpression) { + pub fn substitute_by_known(&mut self, variable: &V, substitution: &T) { // Since we substitute by a known value, we do not need to update variable_occurrences. for item in self .variable_occurrences @@ -239,22 +249,18 @@ impl IndexedConstraintSystem &mut self, interaction_index: usize, field_index: usize, - value: T, + value: T::FieldType, ) { let bus_interaction = &mut self.constraint_system.bus_interactions[interaction_index]; let field = bus_interaction.fields_mut().nth(field_index).unwrap(); - *field = QuadraticSymbolicExpression::from_number(value); + *field = GroupedExpression::from_number(value); } - /// Substitute an unknown variable by a QuadraticSymbolicExpression in the whole system. + /// Substitute an unknown variable by a GroupedExpression in the whole system. /// /// Note this does NOT work properly if the variable is used inside a /// known SymbolicExpression. - pub fn substitute_by_unknown( - &mut self, - variable: &V, - substitution: &QuadraticSymbolicExpression, - ) { + pub fn substitute_by_unknown(&mut self, variable: &V, substitution: &GroupedExpression) { let items = self .variable_occurrences .get(variable) @@ -283,7 +289,15 @@ impl IndexedConstraintSystem /// The provided assignments lead to a contradiction in the constraint system. pub struct ContradictingConstraintError; -impl IndexedConstraintSystem { +impl< + T: RuntimeConstant + + ReferencedSymbols + + Substitutable + + ExpressionConvertible + + Display, + V: Clone + Hash + Ord + Eq + Display, + > IndexedConstraintSystemGeneric +{ /// Given a list of assignments, tries to extend it with more assignments, based on the /// constraints in the constraint system. /// Fails if any of the assignments *directly* contradicts any of the constraints. @@ -291,18 +305,17 @@ impl IndexedConstraintSys /// this function only does one step of the derivation. pub fn derive_more_assignments( &self, - assignments: BTreeMap, - range_constraints: &impl RangeConstraintProvider, - bus_interaction_handler: &impl BusInteractionHandler, - ) -> Result, ContradictingConstraintError> { + assignments: BTreeMap, + range_constraints: &impl RangeConstraintProvider, + bus_interaction_handler: &impl BusInteractionHandler, + ) -> Result, ContradictingConstraintError> { let effects = self .constraints_referencing_variables(assignments.keys().cloned()) .map(|constraint| match constraint { ConstraintRef::AlgebraicConstraint(identity) => { let mut identity = identity.clone(); for (variable, value) in assignments.iter() { - identity - .substitute_by_known(variable, &SymbolicExpression::Concrete(*value)); + identity.substitute_by_known(variable, &T::from(*value)); } identity .solve(range_constraints) @@ -312,12 +325,9 @@ impl IndexedConstraintSys ConstraintRef::BusInteraction(bus_interaction) => { let mut bus_interaction = bus_interaction.clone(); for (variable, value) in assignments.iter() { - bus_interaction.fields_mut().for_each(|expr| { - expr.substitute_by_known( - variable, - &SymbolicExpression::Concrete(*value), - ) - }) + bus_interaction + .fields_mut() + .for_each(|expr| expr.substitute_by_known(variable, &T::from(*value))) } bus_interaction .solve(bus_interaction_handler, range_constraints) @@ -331,8 +341,8 @@ impl IndexedConstraintSys .into_iter() .flatten() .filter_map(|effect| { - if let Effect::Assignment(variable, SymbolicExpression::Concrete(value)) = effect { - Some((variable, value)) + if let EffectImpl::Assignment(variable, value) = effect { + Some((variable, value.try_to_number()?)) } else { None } @@ -353,8 +363,8 @@ impl IndexedConstraintSys /// Returns a hash map mapping all unknown variables in the constraint system /// to the items they occur in. -fn variable_occurrences( - constraint_system: &ConstraintSystem, +fn variable_occurrences( + constraint_system: &ConstraintSystemGeneric, ) -> HashMap> { let occurrences_in_algebraic_constraints = constraint_system .algebraic_constraints @@ -382,11 +392,11 @@ fn variable_occurrences( .into_group_map() } -fn substitute_by_known_in_item( - constraint_system: &mut ConstraintSystem, +fn substitute_by_known_in_item, V: Ord + Clone + Eq>( + constraint_system: &mut ConstraintSystemGeneric, item: ConstraintSystemItem, variable: &V, - substitution: &SymbolicExpression, + substitution: &T, ) { match item { ConstraintSystemItem::AlgebraicConstraint(i) => { @@ -400,11 +410,11 @@ fn substitute_by_known_in_item( } } -fn substitute_by_unknown_in_item( - constraint_system: &mut ConstraintSystem, +fn substitute_by_unknown_in_item, V: Ord + Clone + Eq>( + constraint_system: &mut ConstraintSystemGeneric, item: ConstraintSystemItem, variable: &V, - substitution: &QuadraticSymbolicExpression, + substitution: &GroupedExpression, ) { match item { ConstraintSystemItem::AlgebraicConstraint(i) => { @@ -419,7 +429,9 @@ fn substitute_by_unknown_in_item( } } -impl Display for IndexedConstraintSystem { +impl Display + for IndexedConstraintSystemGeneric +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.constraint_system) } @@ -431,7 +443,7 @@ mod tests { use super::*; - fn format_system(s: &IndexedConstraintSystem) -> String { + fn format_system(s: &IndexedConstraintSystemGeneric) -> String { format!( "{} | {}", s.algebraic_constraints().iter().format(" | "), @@ -453,11 +465,11 @@ mod tests { #[test] fn substitute_by_unknown() { - type Qse = QuadraticSymbolicExpression; - let x = Qse::from_unknown_variable("x"); - let y = Qse::from_unknown_variable("y"); - let z = Qse::from_unknown_variable("z"); - let mut s: IndexedConstraintSystem<_, _> = ConstraintSystem { + type Ge = GroupedExpression; + let x = Ge::from_unknown_variable("x"); + let y = Ge::from_unknown_variable("y"); + let z = Ge::from_unknown_variable("z"); + let mut s: IndexedConstraintSystemGeneric<_, _> = ConstraintSystemGeneric { algebraic_constraints: vec![ x.clone() + y.clone(), x.clone() - z.clone(), @@ -471,13 +483,13 @@ mod tests { } .into(); - s.substitute_by_unknown(&"x", &Qse::from_unknown_variable("z")); + s.substitute_by_unknown(&"x", &Ge::from_unknown_variable("z")); assert_eq!(format_system(&s), "y + z | 0 | y - z | z: y * [y, z]"); s.substitute_by_unknown( &"z", - &(Qse::from_unknown_variable("x") + Qse::from_number(GoldilocksField::from(7))), + &(Ge::from_unknown_variable("x") + Ge::from_number(GoldilocksField::from(7))), ); assert_eq!( @@ -488,11 +500,11 @@ mod tests { #[test] fn retain_update_index() { - type Qse = QuadraticSymbolicExpression; - let x = Qse::from_unknown_variable("x"); - let y = Qse::from_unknown_variable("y"); - let z = Qse::from_unknown_variable("z"); - let mut s: IndexedConstraintSystem<_, _> = ConstraintSystem { + type Ge = GroupedExpression; + let x = Ge::from_unknown_variable("x"); + let y = Ge::from_unknown_variable("y"); + let z = Ge::from_unknown_variable("z"); + let mut s: IndexedConstraintSystemGeneric<_, _> = ConstraintSystemGeneric { algebraic_constraints: vec![ x.clone() + y.clone(), x.clone() - z.clone(),