mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-01-10 07:07:59 -05:00
Generic indexed constraint system (#2981)
This commit is contained in:
@@ -5,33 +5,39 @@ use std::{
|
||||
};
|
||||
|
||||
use itertools::Itertools;
|
||||
use powdr_number::FieldElement;
|
||||
use powdr_number::ExpressionConvertible;
|
||||
|
||||
use crate::{
|
||||
constraint_system::{BusInteraction, BusInteractionHandler, ConstraintRef, ConstraintSystem},
|
||||
effect::Effect,
|
||||
grouped_expression::{QuadraticSymbolicExpression, RangeConstraintProvider},
|
||||
constraint_system::{
|
||||
BusInteraction, BusInteractionHandler, ConstraintRef, ConstraintSystemGeneric,
|
||||
},
|
||||
effect::EffectImpl,
|
||||
grouped_expression::{GroupedExpression, RangeConstraintProvider},
|
||||
runtime_constant::{ReferencedSymbols, RuntimeConstant, Substitutable},
|
||||
symbolic_expression::SymbolicExpression,
|
||||
};
|
||||
|
||||
/// Applies multiple substitutions to a ConstraintSystem in an efficient manner.
|
||||
pub fn apply_substitutions<T: FieldElement, V: Hash + Eq + Clone + Ord>(
|
||||
constraint_system: ConstraintSystem<T, V>,
|
||||
substitutions: impl IntoIterator<Item = (V, QuadraticSymbolicExpression<T, V>)>,
|
||||
) -> ConstraintSystem<T, V> {
|
||||
let mut indexed_constraint_system = IndexedConstraintSystem::from(constraint_system);
|
||||
pub fn apply_substitutions<T: RuntimeConstant + Substitutable<V>, V: Hash + Eq + Clone + Ord>(
|
||||
constraint_system: ConstraintSystemGeneric<T, V>,
|
||||
substitutions: impl IntoIterator<Item = (V, GroupedExpression<T, V>)>,
|
||||
) -> ConstraintSystemGeneric<T, V> {
|
||||
let mut indexed_constraint_system = IndexedConstraintSystemGeneric::from(constraint_system);
|
||||
for (variable, substitution) in substitutions {
|
||||
indexed_constraint_system.substitute_by_unknown(&variable, &substitution);
|
||||
}
|
||||
indexed_constraint_system.into()
|
||||
}
|
||||
|
||||
pub type IndexedConstraintSystem<T, V> =
|
||||
IndexedConstraintSystemGeneric<SymbolicExpression<T, V>, V>;
|
||||
|
||||
/// Structure on top of a [`ConstraintSystem`] that stores indices
|
||||
/// to more efficiently update the constraints.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct IndexedConstraintSystem<T: FieldElement, V: Clone + Eq> {
|
||||
pub struct IndexedConstraintSystemGeneric<T: RuntimeConstant, V> {
|
||||
/// The constraint system.
|
||||
constraint_system: ConstraintSystem<T, V>,
|
||||
constraint_system: ConstraintSystemGeneric<T, V>,
|
||||
/// Stores where each unknown variable appears.
|
||||
variable_occurrences: HashMap<V, Vec<ConstraintSystemItem>>,
|
||||
}
|
||||
@@ -42,49 +48,49 @@ enum ConstraintSystemItem {
|
||||
BusInteraction(usize),
|
||||
}
|
||||
|
||||
impl<T: FieldElement, V: Hash + Eq + Clone + Ord> From<ConstraintSystem<T, V>>
|
||||
for IndexedConstraintSystem<T, V>
|
||||
impl<T: RuntimeConstant, V: Hash + Eq + Clone + Ord> From<ConstraintSystemGeneric<T, V>>
|
||||
for IndexedConstraintSystemGeneric<T, V>
|
||||
{
|
||||
fn from(constraint_system: ConstraintSystem<T, V>) -> Self {
|
||||
fn from(constraint_system: ConstraintSystemGeneric<T, V>) -> Self {
|
||||
let variable_occurrences = variable_occurrences(&constraint_system);
|
||||
IndexedConstraintSystem {
|
||||
IndexedConstraintSystemGeneric {
|
||||
constraint_system,
|
||||
variable_occurrences,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FieldElement, V: Clone + Eq> From<IndexedConstraintSystem<T, V>>
|
||||
for ConstraintSystem<T, V>
|
||||
impl<T: RuntimeConstant, V: Clone + Eq> From<IndexedConstraintSystemGeneric<T, V>>
|
||||
for ConstraintSystemGeneric<T, V>
|
||||
{
|
||||
fn from(indexed_constraint_system: IndexedConstraintSystem<T, V>) -> Self {
|
||||
fn from(indexed_constraint_system: IndexedConstraintSystemGeneric<T, V>) -> Self {
|
||||
indexed_constraint_system.constraint_system
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FieldElement, V: Clone + Eq> IndexedConstraintSystem<T, V> {
|
||||
pub fn system(&self) -> &ConstraintSystem<T, V> {
|
||||
impl<T: RuntimeConstant, V: Clone + Eq> IndexedConstraintSystemGeneric<T, V> {
|
||||
pub fn system(&self) -> &ConstraintSystemGeneric<T, V> {
|
||||
&self.constraint_system
|
||||
}
|
||||
|
||||
pub fn algebraic_constraints(&self) -> &[QuadraticSymbolicExpression<T, V>] {
|
||||
pub fn algebraic_constraints(&self) -> &[GroupedExpression<T, V>] {
|
||||
&self.constraint_system.algebraic_constraints
|
||||
}
|
||||
|
||||
pub fn bus_interactions(&self) -> &[BusInteraction<QuadraticSymbolicExpression<T, V>>] {
|
||||
pub fn bus_interactions(&self) -> &[BusInteraction<GroupedExpression<T, V>>] {
|
||||
&self.constraint_system.bus_interactions
|
||||
}
|
||||
|
||||
/// Returns all expressions that appear in the constraint system, i.e. all algebraic
|
||||
/// constraints and all expressions in bus interactions.
|
||||
pub fn expressions(&self) -> impl Iterator<Item = &QuadraticSymbolicExpression<T, V>> {
|
||||
pub fn expressions(&self) -> impl Iterator<Item = &GroupedExpression<T, V>> {
|
||||
self.constraint_system.expressions()
|
||||
}
|
||||
|
||||
/// Removes all constraints that do not fulfill the predicate.
|
||||
pub fn retain_algebraic_constraints(
|
||||
&mut self,
|
||||
mut f: impl FnMut(&QuadraticSymbolicExpression<T, V>) -> bool,
|
||||
mut f: impl FnMut(&GroupedExpression<T, V>) -> bool,
|
||||
) {
|
||||
retain(
|
||||
&mut self.constraint_system.algebraic_constraints,
|
||||
@@ -97,7 +103,7 @@ impl<T: FieldElement, V: Clone + Eq> IndexedConstraintSystem<T, V> {
|
||||
/// Removes all bus interactions that do not fulfill the predicate.
|
||||
pub fn retain_bus_interactions(
|
||||
&mut self,
|
||||
mut f: impl FnMut(&BusInteraction<QuadraticSymbolicExpression<T, V>>) -> bool,
|
||||
mut f: impl FnMut(&BusInteraction<GroupedExpression<T, V>>) -> bool,
|
||||
) {
|
||||
retain(
|
||||
&mut self.constraint_system.bus_interactions,
|
||||
@@ -154,13 +160,13 @@ fn retain<V, Item>(
|
||||
});
|
||||
}
|
||||
|
||||
impl<T: FieldElement, V: Clone + Ord + Hash> IndexedConstraintSystem<T, V> {
|
||||
impl<T: RuntimeConstant, V: Clone + Ord + Hash> IndexedConstraintSystemGeneric<T, V> {
|
||||
/// Adds new algebraic constraints to the system.
|
||||
pub fn add_algebraic_constraints(
|
||||
&mut self,
|
||||
constraints: impl IntoIterator<Item = QuadraticSymbolicExpression<T, V>>,
|
||||
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
|
||||
) {
|
||||
self.extend(ConstraintSystem {
|
||||
self.extend(ConstraintSystemGeneric {
|
||||
algebraic_constraints: constraints.into_iter().collect(),
|
||||
bus_interactions: Vec::new(),
|
||||
});
|
||||
@@ -169,16 +175,16 @@ impl<T: FieldElement, V: Clone + Ord + Hash> IndexedConstraintSystem<T, V> {
|
||||
/// Adds new bus interactions to the system.
|
||||
pub fn add_bus_interactions(
|
||||
&mut self,
|
||||
bus_interactions: impl IntoIterator<Item = BusInteraction<QuadraticSymbolicExpression<T, V>>>,
|
||||
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
|
||||
) {
|
||||
self.extend(ConstraintSystem {
|
||||
self.extend(ConstraintSystemGeneric {
|
||||
algebraic_constraints: Vec::new(),
|
||||
bus_interactions: bus_interactions.into_iter().collect(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Extends the constraint system by the constraints of another system.
|
||||
pub fn extend(&mut self, system: ConstraintSystem<T, V>) {
|
||||
pub fn extend(&mut self, system: ConstraintSystemGeneric<T, V>) {
|
||||
let algebraic_constraint_count = self.constraint_system.algebraic_constraints.len();
|
||||
let bus_interactions_count = self.constraint_system.bus_interactions.len();
|
||||
// Compute the occurrences of the variables in the new constraints,
|
||||
@@ -203,12 +209,12 @@ impl<T: FieldElement, V: Clone + Ord + Hash> IndexedConstraintSystem<T, V> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V> {
|
||||
impl<T: RuntimeConstant, V: Clone + Hash + Ord + Eq> IndexedConstraintSystemGeneric<T, V> {
|
||||
/// Returns a list of all constraints that contain at least one of the given variables.
|
||||
pub fn constraints_referencing_variables<'a>(
|
||||
&'a self,
|
||||
variables: impl Iterator<Item = V> + 'a,
|
||||
) -> impl Iterator<Item = ConstraintRef<'a, SymbolicExpression<T, V>, V>> + 'a {
|
||||
) -> impl Iterator<Item = ConstraintRef<'a, T, V>> + 'a {
|
||||
variables
|
||||
.filter_map(|v| self.variable_occurrences.get(&v))
|
||||
.flatten()
|
||||
@@ -222,9 +228,13 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V>
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RuntimeConstant + Substitutable<V>, V: Clone + Hash + Ord + Eq>
|
||||
IndexedConstraintSystemGeneric<T, V>
|
||||
{
|
||||
/// Substitutes a variable with a symbolic expression in the whole system
|
||||
pub fn substitute_by_known(&mut self, variable: &V, substitution: &SymbolicExpression<T, V>) {
|
||||
pub fn substitute_by_known(&mut self, variable: &V, substitution: &T) {
|
||||
// Since we substitute by a known value, we do not need to update variable_occurrences.
|
||||
for item in self
|
||||
.variable_occurrences
|
||||
@@ -239,22 +249,18 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V>
|
||||
&mut self,
|
||||
interaction_index: usize,
|
||||
field_index: usize,
|
||||
value: T,
|
||||
value: T::FieldType,
|
||||
) {
|
||||
let bus_interaction = &mut self.constraint_system.bus_interactions[interaction_index];
|
||||
let field = bus_interaction.fields_mut().nth(field_index).unwrap();
|
||||
*field = QuadraticSymbolicExpression::from_number(value);
|
||||
*field = GroupedExpression::from_number(value);
|
||||
}
|
||||
|
||||
/// Substitute an unknown variable by a QuadraticSymbolicExpression in the whole system.
|
||||
/// Substitute an unknown variable by a GroupedExpression in the whole system.
|
||||
///
|
||||
/// Note this does NOT work properly if the variable is used inside a
|
||||
/// known SymbolicExpression.
|
||||
pub fn substitute_by_unknown(
|
||||
&mut self,
|
||||
variable: &V,
|
||||
substitution: &QuadraticSymbolicExpression<T, V>,
|
||||
) {
|
||||
pub fn substitute_by_unknown(&mut self, variable: &V, substitution: &GroupedExpression<T, V>) {
|
||||
let items = self
|
||||
.variable_occurrences
|
||||
.get(variable)
|
||||
@@ -283,7 +289,15 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V>
|
||||
/// The provided assignments lead to a contradiction in the constraint system.
|
||||
pub struct ContradictingConstraintError;
|
||||
|
||||
impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSystem<T, V> {
|
||||
impl<
|
||||
T: RuntimeConstant
|
||||
+ ReferencedSymbols<V>
|
||||
+ Substitutable<V>
|
||||
+ ExpressionConvertible<T::FieldType, V>
|
||||
+ Display,
|
||||
V: Clone + Hash + Ord + Eq + Display,
|
||||
> IndexedConstraintSystemGeneric<T, V>
|
||||
{
|
||||
/// Given a list of assignments, tries to extend it with more assignments, based on the
|
||||
/// constraints in the constraint system.
|
||||
/// Fails if any of the assignments *directly* contradicts any of the constraints.
|
||||
@@ -291,18 +305,17 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
|
||||
/// this function only does one step of the derivation.
|
||||
pub fn derive_more_assignments(
|
||||
&self,
|
||||
assignments: BTreeMap<V, T>,
|
||||
range_constraints: &impl RangeConstraintProvider<T, V>,
|
||||
bus_interaction_handler: &impl BusInteractionHandler<T>,
|
||||
) -> Result<BTreeMap<V, T>, ContradictingConstraintError> {
|
||||
assignments: BTreeMap<V, T::FieldType>,
|
||||
range_constraints: &impl RangeConstraintProvider<T::FieldType, V>,
|
||||
bus_interaction_handler: &impl BusInteractionHandler<T::FieldType>,
|
||||
) -> Result<BTreeMap<V, T::FieldType>, ContradictingConstraintError> {
|
||||
let effects = self
|
||||
.constraints_referencing_variables(assignments.keys().cloned())
|
||||
.map(|constraint| match constraint {
|
||||
ConstraintRef::AlgebraicConstraint(identity) => {
|
||||
let mut identity = identity.clone();
|
||||
for (variable, value) in assignments.iter() {
|
||||
identity
|
||||
.substitute_by_known(variable, &SymbolicExpression::Concrete(*value));
|
||||
identity.substitute_by_known(variable, &T::from(*value));
|
||||
}
|
||||
identity
|
||||
.solve(range_constraints)
|
||||
@@ -312,12 +325,9 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
|
||||
ConstraintRef::BusInteraction(bus_interaction) => {
|
||||
let mut bus_interaction = bus_interaction.clone();
|
||||
for (variable, value) in assignments.iter() {
|
||||
bus_interaction.fields_mut().for_each(|expr| {
|
||||
expr.substitute_by_known(
|
||||
variable,
|
||||
&SymbolicExpression::Concrete(*value),
|
||||
)
|
||||
})
|
||||
bus_interaction
|
||||
.fields_mut()
|
||||
.for_each(|expr| expr.substitute_by_known(variable, &T::from(*value)))
|
||||
}
|
||||
bus_interaction
|
||||
.solve(bus_interaction_handler, range_constraints)
|
||||
@@ -331,8 +341,8 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter_map(|effect| {
|
||||
if let Effect::Assignment(variable, SymbolicExpression::Concrete(value)) = effect {
|
||||
Some((variable, value))
|
||||
if let EffectImpl::Assignment(variable, value) = effect {
|
||||
Some((variable, value.try_to_number()?))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -353,8 +363,8 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
|
||||
|
||||
/// Returns a hash map mapping all unknown variables in the constraint system
|
||||
/// to the items they occur in.
|
||||
fn variable_occurrences<T: FieldElement, V: Hash + Eq + Clone + Ord>(
|
||||
constraint_system: &ConstraintSystem<T, V>,
|
||||
fn variable_occurrences<T: RuntimeConstant, V: Hash + Eq + Clone>(
|
||||
constraint_system: &ConstraintSystemGeneric<T, V>,
|
||||
) -> HashMap<V, Vec<ConstraintSystemItem>> {
|
||||
let occurrences_in_algebraic_constraints = constraint_system
|
||||
.algebraic_constraints
|
||||
@@ -382,11 +392,11 @@ fn variable_occurrences<T: FieldElement, V: Hash + Eq + Clone + Ord>(
|
||||
.into_group_map()
|
||||
}
|
||||
|
||||
fn substitute_by_known_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
|
||||
constraint_system: &mut ConstraintSystem<T, V>,
|
||||
fn substitute_by_known_in_item<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq>(
|
||||
constraint_system: &mut ConstraintSystemGeneric<T, V>,
|
||||
item: ConstraintSystemItem,
|
||||
variable: &V,
|
||||
substitution: &SymbolicExpression<T, V>,
|
||||
substitution: &T,
|
||||
) {
|
||||
match item {
|
||||
ConstraintSystemItem::AlgebraicConstraint(i) => {
|
||||
@@ -400,11 +410,11 @@ fn substitute_by_known_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
|
||||
}
|
||||
}
|
||||
|
||||
fn substitute_by_unknown_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
|
||||
constraint_system: &mut ConstraintSystem<T, V>,
|
||||
fn substitute_by_unknown_in_item<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq>(
|
||||
constraint_system: &mut ConstraintSystemGeneric<T, V>,
|
||||
item: ConstraintSystemItem,
|
||||
variable: &V,
|
||||
substitution: &QuadraticSymbolicExpression<T, V>,
|
||||
substitution: &GroupedExpression<T, V>,
|
||||
) {
|
||||
match item {
|
||||
ConstraintSystemItem::AlgebraicConstraint(i) => {
|
||||
@@ -419,7 +429,9 @@ fn substitute_by_unknown_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FieldElement, V: Clone + Ord + Display + Hash> Display for IndexedConstraintSystem<T, V> {
|
||||
impl<T: RuntimeConstant + Display, V: Clone + Ord + Display + Hash> Display
|
||||
for IndexedConstraintSystemGeneric<T, V>
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.constraint_system)
|
||||
}
|
||||
@@ -431,7 +443,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn format_system(s: &IndexedConstraintSystem<GoldilocksField, &'static str>) -> String {
|
||||
fn format_system(s: &IndexedConstraintSystemGeneric<GoldilocksField, &'static str>) -> String {
|
||||
format!(
|
||||
"{} | {}",
|
||||
s.algebraic_constraints().iter().format(" | "),
|
||||
@@ -453,11 +465,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn substitute_by_unknown() {
|
||||
type Qse = QuadraticSymbolicExpression<GoldilocksField, &'static str>;
|
||||
let x = Qse::from_unknown_variable("x");
|
||||
let y = Qse::from_unknown_variable("y");
|
||||
let z = Qse::from_unknown_variable("z");
|
||||
let mut s: IndexedConstraintSystem<_, _> = ConstraintSystem {
|
||||
type Ge = GroupedExpression<GoldilocksField, &'static str>;
|
||||
let x = Ge::from_unknown_variable("x");
|
||||
let y = Ge::from_unknown_variable("y");
|
||||
let z = Ge::from_unknown_variable("z");
|
||||
let mut s: IndexedConstraintSystemGeneric<_, _> = ConstraintSystemGeneric {
|
||||
algebraic_constraints: vec![
|
||||
x.clone() + y.clone(),
|
||||
x.clone() - z.clone(),
|
||||
@@ -471,13 +483,13 @@ mod tests {
|
||||
}
|
||||
.into();
|
||||
|
||||
s.substitute_by_unknown(&"x", &Qse::from_unknown_variable("z"));
|
||||
s.substitute_by_unknown(&"x", &Ge::from_unknown_variable("z"));
|
||||
|
||||
assert_eq!(format_system(&s), "y + z | 0 | y - z | z: y * [y, z]");
|
||||
|
||||
s.substitute_by_unknown(
|
||||
&"z",
|
||||
&(Qse::from_unknown_variable("x") + Qse::from_number(GoldilocksField::from(7))),
|
||||
&(Ge::from_unknown_variable("x") + Ge::from_number(GoldilocksField::from(7))),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
@@ -488,11 +500,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn retain_update_index() {
|
||||
type Qse = QuadraticSymbolicExpression<GoldilocksField, &'static str>;
|
||||
let x = Qse::from_unknown_variable("x");
|
||||
let y = Qse::from_unknown_variable("y");
|
||||
let z = Qse::from_unknown_variable("z");
|
||||
let mut s: IndexedConstraintSystem<_, _> = ConstraintSystem {
|
||||
type Ge = GroupedExpression<GoldilocksField, &'static str>;
|
||||
let x = Ge::from_unknown_variable("x");
|
||||
let y = Ge::from_unknown_variable("y");
|
||||
let z = Ge::from_unknown_variable("z");
|
||||
let mut s: IndexedConstraintSystemGeneric<_, _> = ConstraintSystemGeneric {
|
||||
algebraic_constraints: vec![
|
||||
x.clone() + y.clone(),
|
||||
x.clone() - z.clone(),
|
||||
|
||||
Reference in New Issue
Block a user