Linearizing solver. (#3140)

Adds a solver that introduces new variables for every non-linear
component in an algebraic constraint and also for every bus interaction
field that is not a constant or a variable already.

---------

Co-authored-by: Georg Wiese <georgwiese@gmail.com>
This commit is contained in:
chriseth
2025-08-15 15:21:11 +02:00
committed by GitHub
parent bdf3a5fefa
commit 5d0d97a269
8 changed files with 515 additions and 14 deletions

View File

@@ -21,6 +21,7 @@ bitvec = "1.0.1"
pretty_assertions = "1.4.0"
env_logger = "0.10.0"
test-log = "0.2.12"
expect-test = "1.5.1"
[package.metadata.cargo-udeps.ignore]
development = ["env_logger"]

View File

@@ -2,7 +2,7 @@ use std::{
collections::{BTreeMap, HashSet},
fmt::Display,
hash::Hash,
iter::once,
iter::{once, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub},
};
@@ -238,6 +238,10 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
(&self.quadratic, self.linear.iter(), &self.constant)
}
pub fn into_components(self) -> (Vec<(Self, Self)>, impl Iterator<Item = (V, T)>, T) {
(self.quadratic, self.linear.into_iter(), self.constant)
}
/// Computes the degree of a GroupedExpression in the unknown variables.
/// Note that it might overestimate the degree if the expression contains
/// terms that cancel each other out, e.g. `a * (b + 1) - a * b - a`.
@@ -1126,6 +1130,15 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> MulAssign<&T> for GroupedExpressio
}
}
impl<T: RuntimeConstant, V: Clone + Ord + Eq> Sum for GroupedExpression<T, V> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::zero(), |mut acc, item| {
acc += item;
acc
})
}
}
impl<T: RuntimeConstant, V: Clone + Ord + Eq> Mul for GroupedExpression<T, V> {
type Output = GroupedExpression<T, V>;

View File

@@ -8,6 +8,7 @@ use crate::runtime_constant::{
};
use crate::solver::base::BaseSolver;
use crate::solver::boolean_extracted::BooleanExtractedSolver;
use crate::solver::linearized::LinearizedSolver;
use crate::solver::var_transformation::{VarTransformation, Variable};
use super::grouped_expression::{Error as QseError, RangeConstraintProvider};
@@ -19,6 +20,7 @@ use std::hash::Hash;
mod base;
mod boolean_extracted;
mod exhaustive_search;
mod linearized;
mod quadratic_equivalences;
mod var_transformation;
@@ -28,11 +30,12 @@ pub fn solve_system<T, V>(
bus_interaction_handler: impl BusInteractionHandler<T::FieldType>,
) -> Result<Vec<VariableAssignment<T, V>>, Error>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Hash + Display,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
+ VarTransformable<Variable<V>, V, Transformed = T>
+ ReferencedSymbols<Variable<V>>
+ Display
+ Hash
+ ExpressionConvertible<T::FieldType, Variable<V>>
+ Substitutable<Variable<V>>
+ Hash,
@@ -47,18 +50,19 @@ pub fn new_solver<T, V>(
bus_interaction_handler: impl BusInteractionHandler<T::FieldType>,
) -> impl Solver<T, V>
where
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Display,
T: RuntimeConstant + VarTransformable<V, Variable<V>> + Hash + Display,
T::Transformed: RuntimeConstant<FieldType = T::FieldType>
+ VarTransformable<Variable<V>, V, Transformed = T>
+ ReferencedSymbols<Variable<V>>
+ Display
+ Hash
+ ExpressionConvertible<T::FieldType, Variable<V>>
+ Substitutable<Variable<V>>
+ Hash,
V: Ord + Clone + Hash + Eq + Display,
{
let mut solver = VarTransformation::new(BooleanExtractedSolver::new(BaseSolver::new(
bus_interaction_handler,
let mut solver = VarTransformation::new(BooleanExtractedSolver::new(LinearizedSolver::new(
BaseSolver::new(bus_interaction_handler),
)));
solver.add_algebraic_constraints(constraint_system.algebraic_constraints);
solver.add_bus_interactions(constraint_system.bus_interactions);

View File

@@ -0,0 +1,459 @@
use std::collections::HashMap;
use std::hash::Hash;
use std::iter;
use std::{collections::HashSet, fmt::Display};
use itertools::Itertools;
use crate::constraint_system::ConstraintSystem;
use crate::indexed_constraint_system::apply_substitutions;
use crate::runtime_constant::Substitutable;
use crate::solver::var_transformation::Variable;
use crate::solver::{Error, VariableAssignment};
use crate::{
constraint_system::BusInteraction,
grouped_expression::{GroupedExpression, RangeConstraintProvider},
range_constraint::RangeConstraint,
runtime_constant::{RuntimeConstant, VarTransformable},
solver::Solver,
};
/// A Solver that turns algebraic constraints into affine constraints
/// by introducing new variables for the non-affine parts.
/// It also replaces bus interaction fields by new variables if they are
/// not just variables or constants.
///
/// The original algebraic constraints are kept as well.
pub struct LinearizedSolver<T, V, S> {
solver: S,
linearizer: Linearizer<T, V>,
var_dispenser: LinearizedVarDispenser,
}
struct LinearizedVarDispenser {
next_var_id: usize,
}
impl LinearizedVarDispenser {
fn new() -> Self {
LinearizedVarDispenser { next_var_id: 0 }
}
fn next_var<V>(&mut self) -> Variable<V> {
let id = self.next_var_id;
self.next_var_id += 1;
Variable::Linearized(id)
}
/// Returns an iterator over all variables dispensed in the past.
fn all_dispensed_vars<V>(&self) -> impl Iterator<Item = Variable<V>> {
(0..self.next_var_id).map(Variable::Linearized)
}
}
impl<T, V, S> LinearizedSolver<T, V, S>
where
T: RuntimeConstant,
V: Clone + Eq,
S: Solver<T, V>,
{
pub fn new(solver: S) -> Self {
Self {
solver,
linearizer: Linearizer::default(),
var_dispenser: LinearizedVarDispenser::new(),
}
}
}
impl<T, V, S> RangeConstraintProvider<T::FieldType, Variable<V>>
for LinearizedSolver<T, Variable<V>, S>
where
T: RuntimeConstant,
S: RangeConstraintProvider<T::FieldType, Variable<V>>,
V: Clone,
{
fn get(&self, var: &Variable<V>) -> RangeConstraint<T::FieldType> {
self.solver.get(var)
}
}
impl<T: VarTransformable<V, Variable<V>>, V, S: Display> Display
for LinearizedSolver<T, Variable<V>, S>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.solver)
}
}
impl<T, V, S> Solver<T, Variable<V>> for LinearizedSolver<T, Variable<V>, S>
where
T: RuntimeConstant + Substitutable<Variable<V>> + Display + Hash,
V: Ord + Clone + Eq + Hash + Display,
S: Solver<T, Variable<V>>,
{
fn solve(&mut self) -> Result<Vec<VariableAssignment<T, Variable<V>>>, Error> {
let assignments = self.solver.solve()?;
// Apply the deduced assignments to the substitutions we performed.
// We assume that the user of the solver applies the assignments to
// their expressions and thus "incoming" expressions used in the functions
// `range_constraint_for_expression` and `are_expressions_known_to_be_different`
// will have the assignments applied.
self.linearizer.apply_assignments(&assignments);
Ok(assignments)
}
fn add_algebraic_constraints(
&mut self,
constraints: impl IntoIterator<Item = GroupedExpression<T, Variable<V>>>,
) {
let constraints = constraints
.into_iter()
.flat_map(|constr| {
// We always add the original constraint unmodified.
let mut constrs = vec![constr.clone()];
if !constr.is_affine() {
let linearized = self.linearizer.linearize(
constr,
&mut || self.var_dispenser.next_var(),
&mut constrs,
);
constrs.push(linearized);
}
constrs
})
.collect::<Vec<_>>();
self.solver.add_algebraic_constraints(constraints);
}
fn add_bus_interactions(
&mut self,
bus_interactions: impl IntoIterator<Item = BusInteraction<GroupedExpression<T, Variable<V>>>>,
) {
let mut constraints_to_add = vec![];
let bus_interactions = bus_interactions
.into_iter()
.map(|bus_interaction| {
bus_interaction
.fields()
.map(|expr| {
self.linearizer.substitute_by_var(
expr.clone(),
&mut || self.var_dispenser.next_var(),
&mut constraints_to_add,
)
})
.collect::<BusInteraction<_>>()
})
.collect_vec();
// We only substituted by a variable, but the substitution was not yet linearized.
self.add_algebraic_constraints(constraints_to_add);
self.solver.add_bus_interactions(bus_interactions);
}
fn add_range_constraint(
&mut self,
variable: &Variable<V>,
constraint: RangeConstraint<T::FieldType>,
) {
self.solver.add_range_constraint(variable, constraint);
}
fn retain_variables(&mut self, variables_to_keep: &HashSet<Variable<V>>) {
// There are constraints that only contain `Variable::Linearized` that
// connect quadratic terms with the original constraints. We could try to find
// those, but let's just keep all of them for now.
let mut variables_to_keep = variables_to_keep.clone();
variables_to_keep.extend(self.var_dispenser.all_dispensed_vars());
self.solver.retain_variables(&variables_to_keep);
}
fn range_constraint_for_expression(
&self,
expr: &GroupedExpression<T, Variable<V>>,
) -> RangeConstraint<T::FieldType> {
self.internalized_versions_of_expression(expr)
.fold(RangeConstraint::default(), |acc, expr| {
acc.conjunction(&self.solver.range_constraint_for_expression(&expr))
})
}
fn are_expressions_known_to_be_different(
&mut self,
a: &GroupedExpression<T, Variable<V>>,
b: &GroupedExpression<T, Variable<V>>,
) -> bool {
self.solver.are_expressions_known_to_be_different(a, b)
}
}
impl<T, V, S> LinearizedSolver<T, Variable<V>, S>
where
T: RuntimeConstant + Hash,
V: Ord + Clone + Eq + Hash,
{
/// Returns an iterator over expressions equivalent to `expr` with the idea that
/// they might allow to answer a query better or worse.
/// It usually returns the original expression, a single variable that it was
/// substituted into during a previous linearization and a previously linearized version.
fn internalized_versions_of_expression(
&self,
expr: &GroupedExpression<T, Variable<V>>,
) -> impl Iterator<Item = GroupedExpression<T, Variable<V>>> + Clone {
let direct = expr.clone();
// See if we have a direct substitution for the expression by a variable.
let simple_substituted = self.linearizer.try_substitute_by_existing_var(expr);
// Try to re-do the linearization
let substituted = self.linearizer.try_linearize_existing(expr.clone());
iter::once(direct)
.chain(simple_substituted)
.chain(substituted)
}
}
struct Linearizer<T, V> {
substitutions: HashMap<GroupedExpression<T, V>, V>,
}
impl<T, V> Default for Linearizer<T, V> {
fn default() -> Self {
Linearizer {
substitutions: HashMap::new(),
}
}
}
impl<T: RuntimeConstant + Hash, V: Clone + Eq + Ord + Hash> Linearizer<T, V> {
/// Linearizes the constraint by introducing new variables for
/// non-affine parts. The new constraints are appended to
/// `constraint_collection` and must be added to the system.
/// The linearized expression is returned.
fn linearize(
&mut self,
expr: GroupedExpression<T, V>,
var_dispenser: &mut impl FnMut() -> V,
constraint_collection: &mut impl Extend<GroupedExpression<T, V>>,
) -> GroupedExpression<T, V> {
if expr.is_affine() {
return expr;
}
let (quadratic, linear, constant) = expr.into_components();
quadratic
.into_iter()
.map(|(l, r)| {
let l =
self.linearize_and_substitute_by_var(l, var_dispenser, constraint_collection);
let r =
self.linearize_and_substitute_by_var(r, var_dispenser, constraint_collection);
self.substitute_by_var(l * r, var_dispenser, constraint_collection)
})
.chain(linear.map(|(v, c)| GroupedExpression::from_unknown_variable(v) * c))
.chain(std::iter::once(GroupedExpression::from_runtime_constant(
constant,
)))
.sum()
}
/// Tries to linearize the expression according to already existing substitutions.
fn try_linearize_existing(
&self,
expr: GroupedExpression<T, V>,
) -> Option<GroupedExpression<T, V>> {
if expr.is_affine() {
return Some(expr);
}
let (quadratic, linear, constant) = expr.into_components();
Some(
quadratic
.into_iter()
.map(|(l, r)| {
let l =
self.try_substitute_by_existing_var(&self.try_linearize_existing(l)?)?;
let r =
self.try_substitute_by_existing_var(&self.try_linearize_existing(r)?)?;
self.try_substitute_by_existing_var(&(l * r))
})
.collect::<Option<Vec<_>>>()?
.into_iter()
.chain(linear.map(|(v, c)| GroupedExpression::from_unknown_variable(v) * c))
.chain(std::iter::once(GroupedExpression::from_runtime_constant(
constant,
)))
.sum(),
)
}
/// Linearizes the expression and substitutes the expression by a single variable.
/// The substitution is not performed if the expression is a constant or a single
/// variable (without coefficient).
fn linearize_and_substitute_by_var(
&mut self,
expr: GroupedExpression<T, V>,
var_dispenser: &mut impl FnMut() -> V,
constraint_collection: &mut impl Extend<GroupedExpression<T, V>>,
) -> GroupedExpression<T, V> {
let linearized = self.linearize(expr, var_dispenser, constraint_collection);
self.substitute_by_var(linearized, var_dispenser, constraint_collection)
}
/// Substitutes the given expression by a single variable using the variable dispenser,
/// unless the expression is already just a single variable or constant. Re-uses substitutions
/// that were made in the past.
/// Adds the equality constraint to `constraint_collection` and returns the variable
/// as an expression.
fn substitute_by_var(
&mut self,
expr: GroupedExpression<T, V>,
var_dispenser: &mut impl FnMut() -> V,
constraint_collection: &mut impl Extend<GroupedExpression<T, V>>,
) -> GroupedExpression<T, V> {
if let Some(var) = self.try_substitute_by_existing_var(&expr) {
var
} else {
let var = var_dispenser();
self.substitutions.insert(expr.clone(), var.clone());
let var = GroupedExpression::from_unknown_variable(var);
constraint_collection.extend([expr - var.clone()]);
var
}
}
/// Tries to substitute the given expression by an existing variable.
fn try_substitute_by_existing_var(
&self,
expr: &GroupedExpression<T, V>,
) -> Option<GroupedExpression<T, V>> {
if expr.try_to_known().is_some() || expr.try_to_simple_unknown().is_some() {
Some(expr.clone())
} else {
self.substitutions
.get(expr)
.map(|var| GroupedExpression::from_unknown_variable(var.clone()))
}
}
}
impl<T: RuntimeConstant + Substitutable<V> + Hash, V: Clone + Eq + Ord + Hash> Linearizer<T, V> {
/// Applies the assignments to the stored substitutions.
fn apply_assignments(&mut self, assignments: &[VariableAssignment<T, V>]) {
if assignments.is_empty() {
return;
}
let (exprs, vars): (Vec<_>, Vec<_>) = self.substitutions.drain().unzip();
let exprs = apply_substitutions(
ConstraintSystem {
algebraic_constraints: exprs,
bus_interactions: vec![],
},
assignments.iter().cloned(),
)
.algebraic_constraints;
self.substitutions = exprs
.into_iter()
.zip_eq(vars)
.map(|(expr, var)| (expr, var.clone()))
.collect();
}
}
#[cfg(test)]
mod tests {
use expect_test::expect;
use powdr_number::GoldilocksField;
use super::*;
use crate::{constraint_system::DefaultBusInteractionHandler, solver::base::BaseSolver};
type Qse = GroupedExpression<GoldilocksField, Variable<&'static str>>;
fn var(name: &'static str) -> Qse {
GroupedExpression::from_unknown_variable(Variable::from(name))
}
fn constant(value: u64) -> Qse {
GroupedExpression::from_number(GoldilocksField::from(value))
}
#[test]
fn linearization() {
let mut var_counter = 0usize;
let mut linearizer = Linearizer::default();
let expr = var("x") + var("y") * (var("z") + constant(1)) * (var("x") - constant(1));
let mut constraints_to_add = vec![];
let linearized = linearizer.linearize(
expr,
&mut || {
let var = Variable::Linearized(var_counter);
var_counter += 1;
var
},
&mut constraints_to_add,
);
assert_eq!(linearized.to_string(), "x + lin_3");
assert_eq!(
constraints_to_add.into_iter().format("\n").to_string(),
"z - lin_0 + 1\n(y) * (lin_0) - lin_1\nx - lin_2 - 1\n(lin_1) * (lin_2) - lin_3"
);
}
#[test]
fn solver_transforms() {
let mut solver =
LinearizedSolver::new(BaseSolver::new(DefaultBusInteractionHandler::default()));
solver.add_algebraic_constraints(vec![
(var("x") + var("y")) * (var("z") + constant(1)) * (var("x") - constant(1)),
(var("a") + var("b")) * (var("c") - constant(2)),
]);
solver.add_bus_interactions(vec![BusInteraction {
bus_id: constant(1),
payload: vec![var("x") + var("y"), -var("a"), var("a")],
multiplicity: var("z") + constant(1),
}]);
// Below, it is important that in the bus interaction,
// `a` is not replaced and that the first payload re-uses the
// already linearized `x + y`.
expect!([r#"
((x + y) * (z + 1)) * (x - 1) = 0
x + y - lin_0 = 0
z - lin_1 + 1 = 0
(lin_0) * (lin_1) - lin_2 = 0
x - lin_3 - 1 = 0
(lin_2) * (lin_3) - lin_4 = 0
lin_4 = 0
(a + b) * (c - 2) = 0
a + b - lin_5 = 0
c - lin_6 - 2 = 0
(lin_5) * (lin_6) - lin_7 = 0
lin_7 = 0
-(a + lin_8) = 0
BusInteraction { bus_id: 1, multiplicity: lin_1, payload: lin_0, lin_8, a }"#])
.assert_eq(&solver.to_string());
let assignments = solver.solve().unwrap();
expect!([r#"
lin_4 = 0
lin_7 = 0"#])
.assert_eq(
&assignments
.iter()
.map(|(var, value)| format!("{var} = {value}"))
.join("\n"),
);
expect!([r#"
((x + y) * (z + 1)) * (x - 1) = 0
x + y - lin_0 = 0
z - lin_1 + 1 = 0
(lin_0) * (lin_1) - lin_2 = 0
x - lin_3 - 1 = 0
(lin_2) * (lin_3) = 0
0 = 0
(a + b) * (c - 2) = 0
a + b - lin_5 = 0
c - lin_6 - 2 = 0
(lin_5) * (lin_6) = 0
0 = 0
-(a + lin_8) = 0
BusInteraction { bus_id: 1, multiplicity: lin_1, payload: lin_0, lin_8, a }"#])
.assert_eq(&solver.to_string());
}
}

View File

@@ -17,7 +17,7 @@ pub enum Variable<V> {
/// A new boolean-constrained variable that was introduced by the solver.
Boolean(usize),
/// A new variable introduced by the linearizer.
Linear(usize),
Linearized(usize),
}
impl<V> From<V> for Variable<V> {
@@ -48,7 +48,7 @@ impl<V: Display> Display for Variable<V> {
match self {
Variable::Original(v) => write!(f, "{v}"),
Variable::Boolean(i) => write!(f, "bool_{i}"),
Variable::Linear(i) => write!(f, "lin_{i}"),
Variable::Linearized(i) => write!(f, "lin_{i}"),
}
}
}

View File

@@ -438,3 +438,28 @@ fn ternary_flags() {
vec![("is_load", 1.into())],
);
}
#[test]
fn bit_decomposition_bug() {
let algebraic_constraints = vec![
var("cmp_result_0") * (var("cmp_result_0") - constant(1)),
var("imm_0") - constant(8),
var("cmp_result_0") * var("imm_0")
- constant(4) * var("cmp_result_0")
- var("BusInteractionField(10, 2)")
+ constant(4),
(var("BusInteractionField(10, 2)") - constant(4))
* (var("BusInteractionField(10, 2)") - constant(8)),
];
let constraint_system = ConstraintSystem {
algebraic_constraints,
bus_interactions: vec![],
};
// The solver used to infer more assignments due to a bug
// in the bit decomposition logic.
assert_solve_result(
constraint_system,
DefaultBusInteractionHandler::default(),
vec![("imm_0", 8.into())],
);
}

View File

@@ -21,7 +21,11 @@ fn build_reparse_test(kind: &str, dir: &str) {
build_tests(kind, dir, "", "reparse")
}
const SLOW_LIST: [&str; 1] = ["keccakf16_test"];
const SLOW_LIST: [&str; 3] = [
"keccakf16_test",
"keccakf16_memory_test",
"keccakf32_memory_test",
];
#[allow(clippy::print_stdout)]
fn build_tests(kind: &str, dir: &str, sub_dir: &str, name: &str) {

View File

@@ -487,12 +487,7 @@ mod reparse {
/// but these tests panic if the field is too small. This is *probably*
/// fine, because all of these tests have a similar variant that does
/// run on Goldilocks.
const BLACKLIST: [&str; 4] = [
"std/poseidon_bn254_test.asm",
"std/split_bn254_test.asm",
"keccakf16_memory_test",
"keccakf32_memory_test",
];
const BLACKLIST: [&str; 2] = ["std/poseidon_bn254_test.asm", "std/split_bn254_test.asm"];
fn run_reparse_test(file: &str) {
run_reparse_test_with_blacklist(file, &BLACKLIST);