Generic indexed constraint system (#2981)

This commit is contained in:
chriseth
2025-07-03 12:02:05 +02:00
committed by GitHub
parent d3b6d40c7c
commit 718ea2b6b3

View File

@@ -5,33 +5,39 @@ use std::{
};
use itertools::Itertools;
use powdr_number::FieldElement;
use powdr_number::ExpressionConvertible;
use crate::{
constraint_system::{BusInteraction, BusInteractionHandler, ConstraintRef, ConstraintSystem},
effect::Effect,
grouped_expression::{QuadraticSymbolicExpression, RangeConstraintProvider},
constraint_system::{
BusInteraction, BusInteractionHandler, ConstraintRef, ConstraintSystemGeneric,
},
effect::EffectImpl,
grouped_expression::{GroupedExpression, RangeConstraintProvider},
runtime_constant::{ReferencedSymbols, RuntimeConstant, Substitutable},
symbolic_expression::SymbolicExpression,
};
/// Applies multiple substitutions to a ConstraintSystem in an efficient manner.
pub fn apply_substitutions<T: FieldElement, V: Hash + Eq + Clone + Ord>(
constraint_system: ConstraintSystem<T, V>,
substitutions: impl IntoIterator<Item = (V, QuadraticSymbolicExpression<T, V>)>,
) -> ConstraintSystem<T, V> {
let mut indexed_constraint_system = IndexedConstraintSystem::from(constraint_system);
pub fn apply_substitutions<T: RuntimeConstant + Substitutable<V>, V: Hash + Eq + Clone + Ord>(
constraint_system: ConstraintSystemGeneric<T, V>,
substitutions: impl IntoIterator<Item = (V, GroupedExpression<T, V>)>,
) -> ConstraintSystemGeneric<T, V> {
let mut indexed_constraint_system = IndexedConstraintSystemGeneric::from(constraint_system);
for (variable, substitution) in substitutions {
indexed_constraint_system.substitute_by_unknown(&variable, &substitution);
}
indexed_constraint_system.into()
}
pub type IndexedConstraintSystem<T, V> =
IndexedConstraintSystemGeneric<SymbolicExpression<T, V>, V>;
/// Structure on top of a [`ConstraintSystem`] that stores indices
/// to more efficiently update the constraints.
#[derive(Clone, Default)]
pub struct IndexedConstraintSystem<T: FieldElement, V: Clone + Eq> {
pub struct IndexedConstraintSystemGeneric<T: RuntimeConstant, V> {
/// The constraint system.
constraint_system: ConstraintSystem<T, V>,
constraint_system: ConstraintSystemGeneric<T, V>,
/// Stores where each unknown variable appears.
variable_occurrences: HashMap<V, Vec<ConstraintSystemItem>>,
}
@@ -42,49 +48,49 @@ enum ConstraintSystemItem {
BusInteraction(usize),
}
impl<T: FieldElement, V: Hash + Eq + Clone + Ord> From<ConstraintSystem<T, V>>
for IndexedConstraintSystem<T, V>
impl<T: RuntimeConstant, V: Hash + Eq + Clone + Ord> From<ConstraintSystemGeneric<T, V>>
for IndexedConstraintSystemGeneric<T, V>
{
fn from(constraint_system: ConstraintSystem<T, V>) -> Self {
fn from(constraint_system: ConstraintSystemGeneric<T, V>) -> Self {
let variable_occurrences = variable_occurrences(&constraint_system);
IndexedConstraintSystem {
IndexedConstraintSystemGeneric {
constraint_system,
variable_occurrences,
}
}
}
impl<T: FieldElement, V: Clone + Eq> From<IndexedConstraintSystem<T, V>>
for ConstraintSystem<T, V>
impl<T: RuntimeConstant, V: Clone + Eq> From<IndexedConstraintSystemGeneric<T, V>>
for ConstraintSystemGeneric<T, V>
{
fn from(indexed_constraint_system: IndexedConstraintSystem<T, V>) -> Self {
fn from(indexed_constraint_system: IndexedConstraintSystemGeneric<T, V>) -> Self {
indexed_constraint_system.constraint_system
}
}
impl<T: FieldElement, V: Clone + Eq> IndexedConstraintSystem<T, V> {
pub fn system(&self) -> &ConstraintSystem<T, V> {
impl<T: RuntimeConstant, V: Clone + Eq> IndexedConstraintSystemGeneric<T, V> {
pub fn system(&self) -> &ConstraintSystemGeneric<T, V> {
&self.constraint_system
}
pub fn algebraic_constraints(&self) -> &[QuadraticSymbolicExpression<T, V>] {
pub fn algebraic_constraints(&self) -> &[GroupedExpression<T, V>] {
&self.constraint_system.algebraic_constraints
}
pub fn bus_interactions(&self) -> &[BusInteraction<QuadraticSymbolicExpression<T, V>>] {
pub fn bus_interactions(&self) -> &[BusInteraction<GroupedExpression<T, V>>] {
&self.constraint_system.bus_interactions
}
/// Returns all expressions that appear in the constraint system, i.e. all algebraic
/// constraints and all expressions in bus interactions.
pub fn expressions(&self) -> impl Iterator<Item = &QuadraticSymbolicExpression<T, V>> {
pub fn expressions(&self) -> impl Iterator<Item = &GroupedExpression<T, V>> {
self.constraint_system.expressions()
}
/// Removes all constraints that do not fulfill the predicate.
pub fn retain_algebraic_constraints(
&mut self,
mut f: impl FnMut(&QuadraticSymbolicExpression<T, V>) -> bool,
mut f: impl FnMut(&GroupedExpression<T, V>) -> bool,
) {
retain(
&mut self.constraint_system.algebraic_constraints,
@@ -97,7 +103,7 @@ impl<T: FieldElement, V: Clone + Eq> IndexedConstraintSystem<T, V> {
/// Removes all bus interactions that do not fulfill the predicate.
pub fn retain_bus_interactions(
&mut self,
mut f: impl FnMut(&BusInteraction<QuadraticSymbolicExpression<T, V>>) -> bool,
mut f: impl FnMut(&BusInteraction<GroupedExpression<T, V>>) -> bool,
) {
retain(
&mut self.constraint_system.bus_interactions,
@@ -154,13 +160,13 @@ fn retain<V, Item>(
});
}
impl<T: FieldElement, V: Clone + Ord + Hash> IndexedConstraintSystem<T, V> {
impl<T: RuntimeConstant, V: Clone + Ord + Hash> IndexedConstraintSystemGeneric<T, V> {
/// Adds new algebraic constraints to the system.
pub fn add_algebraic_constraints(
&mut self,
constraints: impl IntoIterator<Item = QuadraticSymbolicExpression<T, V>>,
constraints: impl IntoIterator<Item = GroupedExpression<T, V>>,
) {
self.extend(ConstraintSystem {
self.extend(ConstraintSystemGeneric {
algebraic_constraints: constraints.into_iter().collect(),
bus_interactions: Vec::new(),
});
@@ -169,16 +175,16 @@ impl<T: FieldElement, V: Clone + Ord + Hash> IndexedConstraintSystem<T, V> {
/// Adds new bus interactions to the system.
pub fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<QuadraticSymbolicExpression<T, V>>>,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, V>>>,
) {
self.extend(ConstraintSystem {
self.extend(ConstraintSystemGeneric {
algebraic_constraints: Vec::new(),
bus_interactions: bus_interactions.into_iter().collect(),
});
}
/// Extends the constraint system by the constraints of another system.
pub fn extend(&mut self, system: ConstraintSystem<T, V>) {
pub fn extend(&mut self, system: ConstraintSystemGeneric<T, V>) {
let algebraic_constraint_count = self.constraint_system.algebraic_constraints.len();
let bus_interactions_count = self.constraint_system.bus_interactions.len();
// Compute the occurrences of the variables in the new constraints,
@@ -203,12 +209,12 @@ impl<T: FieldElement, V: Clone + Ord + Hash> IndexedConstraintSystem<T, V> {
}
}
impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V> {
impl<T: RuntimeConstant, V: Clone + Hash + Ord + Eq> IndexedConstraintSystemGeneric<T, V> {
/// Returns a list of all constraints that contain at least one of the given variables.
pub fn constraints_referencing_variables<'a>(
&'a self,
variables: impl Iterator<Item = V> + 'a,
) -> impl Iterator<Item = ConstraintRef<'a, SymbolicExpression<T, V>, V>> + 'a {
) -> impl Iterator<Item = ConstraintRef<'a, T, V>> + 'a {
variables
.filter_map(|v| self.variable_occurrences.get(&v))
.flatten()
@@ -222,9 +228,13 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V>
}
})
}
}
impl<T: RuntimeConstant + Substitutable<V>, V: Clone + Hash + Ord + Eq>
IndexedConstraintSystemGeneric<T, V>
{
/// Substitutes a variable with a symbolic expression in the whole system
pub fn substitute_by_known(&mut self, variable: &V, substitution: &SymbolicExpression<T, V>) {
pub fn substitute_by_known(&mut self, variable: &V, substitution: &T) {
// Since we substitute by a known value, we do not need to update variable_occurrences.
for item in self
.variable_occurrences
@@ -239,22 +249,18 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V>
&mut self,
interaction_index: usize,
field_index: usize,
value: T,
value: T::FieldType,
) {
let bus_interaction = &mut self.constraint_system.bus_interactions[interaction_index];
let field = bus_interaction.fields_mut().nth(field_index).unwrap();
*field = QuadraticSymbolicExpression::from_number(value);
*field = GroupedExpression::from_number(value);
}
/// Substitute an unknown variable by a QuadraticSymbolicExpression in the whole system.
/// Substitute an unknown variable by a GroupedExpression in the whole system.
///
/// Note this does NOT work properly if the variable is used inside a
/// known SymbolicExpression.
pub fn substitute_by_unknown(
&mut self,
variable: &V,
substitution: &QuadraticSymbolicExpression<T, V>,
) {
pub fn substitute_by_unknown(&mut self, variable: &V, substitution: &GroupedExpression<T, V>) {
let items = self
.variable_occurrences
.get(variable)
@@ -283,7 +289,15 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq> IndexedConstraintSystem<T, V>
/// The provided assignments lead to a contradiction in the constraint system.
pub struct ContradictingConstraintError;
impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSystem<T, V> {
impl<
T: RuntimeConstant
+ ReferencedSymbols<V>
+ Substitutable<V>
+ ExpressionConvertible<T::FieldType, V>
+ Display,
V: Clone + Hash + Ord + Eq + Display,
> IndexedConstraintSystemGeneric<T, V>
{
/// Given a list of assignments, tries to extend it with more assignments, based on the
/// constraints in the constraint system.
/// Fails if any of the assignments *directly* contradicts any of the constraints.
@@ -291,18 +305,17 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
/// this function only does one step of the derivation.
pub fn derive_more_assignments(
&self,
assignments: BTreeMap<V, T>,
range_constraints: &impl RangeConstraintProvider<T, V>,
bus_interaction_handler: &impl BusInteractionHandler<T>,
) -> Result<BTreeMap<V, T>, ContradictingConstraintError> {
assignments: BTreeMap<V, T::FieldType>,
range_constraints: &impl RangeConstraintProvider<T::FieldType, V>,
bus_interaction_handler: &impl BusInteractionHandler<T::FieldType>,
) -> Result<BTreeMap<V, T::FieldType>, ContradictingConstraintError> {
let effects = self
.constraints_referencing_variables(assignments.keys().cloned())
.map(|constraint| match constraint {
ConstraintRef::AlgebraicConstraint(identity) => {
let mut identity = identity.clone();
for (variable, value) in assignments.iter() {
identity
.substitute_by_known(variable, &SymbolicExpression::Concrete(*value));
identity.substitute_by_known(variable, &T::from(*value));
}
identity
.solve(range_constraints)
@@ -312,12 +325,9 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
ConstraintRef::BusInteraction(bus_interaction) => {
let mut bus_interaction = bus_interaction.clone();
for (variable, value) in assignments.iter() {
bus_interaction.fields_mut().for_each(|expr| {
expr.substitute_by_known(
variable,
&SymbolicExpression::Concrete(*value),
)
})
bus_interaction
.fields_mut()
.for_each(|expr| expr.substitute_by_known(variable, &T::from(*value)))
}
bus_interaction
.solve(bus_interaction_handler, range_constraints)
@@ -331,8 +341,8 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
.into_iter()
.flatten()
.filter_map(|effect| {
if let Effect::Assignment(variable, SymbolicExpression::Concrete(value)) = effect {
Some((variable, value))
if let EffectImpl::Assignment(variable, value) = effect {
Some((variable, value.try_to_number()?))
} else {
None
}
@@ -353,8 +363,8 @@ impl<T: FieldElement, V: Clone + Hash + Ord + Eq + Display> IndexedConstraintSys
/// Returns a hash map mapping all unknown variables in the constraint system
/// to the items they occur in.
fn variable_occurrences<T: FieldElement, V: Hash + Eq + Clone + Ord>(
constraint_system: &ConstraintSystem<T, V>,
fn variable_occurrences<T: RuntimeConstant, V: Hash + Eq + Clone>(
constraint_system: &ConstraintSystemGeneric<T, V>,
) -> HashMap<V, Vec<ConstraintSystemItem>> {
let occurrences_in_algebraic_constraints = constraint_system
.algebraic_constraints
@@ -382,11 +392,11 @@ fn variable_occurrences<T: FieldElement, V: Hash + Eq + Clone + Ord>(
.into_group_map()
}
fn substitute_by_known_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
constraint_system: &mut ConstraintSystem<T, V>,
fn substitute_by_known_in_item<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq>(
constraint_system: &mut ConstraintSystemGeneric<T, V>,
item: ConstraintSystemItem,
variable: &V,
substitution: &SymbolicExpression<T, V>,
substitution: &T,
) {
match item {
ConstraintSystemItem::AlgebraicConstraint(i) => {
@@ -400,11 +410,11 @@ fn substitute_by_known_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
}
}
fn substitute_by_unknown_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
constraint_system: &mut ConstraintSystem<T, V>,
fn substitute_by_unknown_in_item<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq>(
constraint_system: &mut ConstraintSystemGeneric<T, V>,
item: ConstraintSystemItem,
variable: &V,
substitution: &QuadraticSymbolicExpression<T, V>,
substitution: &GroupedExpression<T, V>,
) {
match item {
ConstraintSystemItem::AlgebraicConstraint(i) => {
@@ -419,7 +429,9 @@ fn substitute_by_unknown_in_item<T: FieldElement, V: Ord + Clone + Hash + Eq>(
}
}
impl<T: FieldElement, V: Clone + Ord + Display + Hash> Display for IndexedConstraintSystem<T, V> {
impl<T: RuntimeConstant + Display, V: Clone + Ord + Display + Hash> Display
for IndexedConstraintSystemGeneric<T, V>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.constraint_system)
}
@@ -431,7 +443,7 @@ mod tests {
use super::*;
fn format_system(s: &IndexedConstraintSystem<GoldilocksField, &'static str>) -> String {
fn format_system(s: &IndexedConstraintSystemGeneric<GoldilocksField, &'static str>) -> String {
format!(
"{} | {}",
s.algebraic_constraints().iter().format(" | "),
@@ -453,11 +465,11 @@ mod tests {
#[test]
fn substitute_by_unknown() {
type Qse = QuadraticSymbolicExpression<GoldilocksField, &'static str>;
let x = Qse::from_unknown_variable("x");
let y = Qse::from_unknown_variable("y");
let z = Qse::from_unknown_variable("z");
let mut s: IndexedConstraintSystem<_, _> = ConstraintSystem {
type Ge = GroupedExpression<GoldilocksField, &'static str>;
let x = Ge::from_unknown_variable("x");
let y = Ge::from_unknown_variable("y");
let z = Ge::from_unknown_variable("z");
let mut s: IndexedConstraintSystemGeneric<_, _> = ConstraintSystemGeneric {
algebraic_constraints: vec![
x.clone() + y.clone(),
x.clone() - z.clone(),
@@ -471,13 +483,13 @@ mod tests {
}
.into();
s.substitute_by_unknown(&"x", &Qse::from_unknown_variable("z"));
s.substitute_by_unknown(&"x", &Ge::from_unknown_variable("z"));
assert_eq!(format_system(&s), "y + z | 0 | y - z | z: y * [y, z]");
s.substitute_by_unknown(
&"z",
&(Qse::from_unknown_variable("x") + Qse::from_number(GoldilocksField::from(7))),
&(Ge::from_unknown_variable("x") + Ge::from_number(GoldilocksField::from(7))),
);
assert_eq!(
@@ -488,11 +500,11 @@ mod tests {
#[test]
fn retain_update_index() {
type Qse = QuadraticSymbolicExpression<GoldilocksField, &'static str>;
let x = Qse::from_unknown_variable("x");
let y = Qse::from_unknown_variable("y");
let z = Qse::from_unknown_variable("z");
let mut s: IndexedConstraintSystem<_, _> = ConstraintSystem {
type Ge = GroupedExpression<GoldilocksField, &'static str>;
let x = Ge::from_unknown_variable("x");
let y = Ge::from_unknown_variable("y");
let z = Ge::from_unknown_variable("z");
let mut s: IndexedConstraintSystemGeneric<_, _> = ConstraintSystemGeneric {
algebraic_constraints: vec![
x.clone() + y.clone(),
x.clone() - z.clone(),