diff --git a/constraint-solver/Cargo.toml b/constraint-solver/Cargo.toml index 9e2fe4f02..12a7dc89b 100644 --- a/constraint-solver/Cargo.toml +++ b/constraint-solver/Cargo.toml @@ -21,6 +21,7 @@ bitvec = "1.0.1" pretty_assertions = "1.4.0" env_logger = "0.10.0" test-log = "0.2.12" +expect-test = "1.5.1" [package.metadata.cargo-udeps.ignore] development = ["env_logger"] diff --git a/constraint-solver/src/grouped_expression.rs b/constraint-solver/src/grouped_expression.rs index 8a028139f..4ecf1c5ba 100644 --- a/constraint-solver/src/grouped_expression.rs +++ b/constraint-solver/src/grouped_expression.rs @@ -2,7 +2,7 @@ use std::{ collections::{BTreeMap, HashSet}, fmt::Display, hash::Hash, - iter::once, + iter::{once, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}, }; @@ -238,6 +238,10 @@ impl GroupedExpression { (&self.quadratic, self.linear.iter(), &self.constant) } + pub fn into_components(self) -> (Vec<(Self, Self)>, impl Iterator, T) { + (self.quadratic, self.linear.into_iter(), self.constant) + } + /// Computes the degree of a GroupedExpression in the unknown variables. /// Note that it might overestimate the degree if the expression contains /// terms that cancel each other out, e.g. `a * (b + 1) - a * b - a`. @@ -1126,6 +1130,15 @@ impl MulAssign<&T> for GroupedExpressio } } +impl Sum for GroupedExpression { + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), |mut acc, item| { + acc += item; + acc + }) + } +} + impl Mul for GroupedExpression { type Output = GroupedExpression; diff --git a/constraint-solver/src/solver.rs b/constraint-solver/src/solver.rs index c97f68dc5..ffc687af9 100644 --- a/constraint-solver/src/solver.rs +++ b/constraint-solver/src/solver.rs @@ -8,6 +8,7 @@ use crate::runtime_constant::{ }; use crate::solver::base::BaseSolver; use crate::solver::boolean_extracted::BooleanExtractedSolver; +use crate::solver::linearized::LinearizedSolver; use crate::solver::var_transformation::{VarTransformation, Variable}; use super::grouped_expression::{Error as QseError, RangeConstraintProvider}; @@ -19,6 +20,7 @@ use std::hash::Hash; mod base; mod boolean_extracted; mod exhaustive_search; +mod linearized; mod quadratic_equivalences; mod var_transformation; @@ -28,11 +30,12 @@ pub fn solve_system( bus_interaction_handler: impl BusInteractionHandler, ) -> Result>, Error> where - T: RuntimeConstant + VarTransformable> + Display, + T: RuntimeConstant + VarTransformable> + Hash + Display, T::Transformed: RuntimeConstant + VarTransformable, V, Transformed = T> + ReferencedSymbols> + Display + + Hash + ExpressionConvertible> + Substitutable> + Hash, @@ -47,18 +50,19 @@ pub fn new_solver( bus_interaction_handler: impl BusInteractionHandler, ) -> impl Solver where - T: RuntimeConstant + VarTransformable> + Display, + T: RuntimeConstant + VarTransformable> + Hash + Display, T::Transformed: RuntimeConstant + VarTransformable, V, Transformed = T> + ReferencedSymbols> + Display + + Hash + ExpressionConvertible> + Substitutable> + Hash, V: Ord + Clone + Hash + Eq + Display, { - let mut solver = VarTransformation::new(BooleanExtractedSolver::new(BaseSolver::new( - bus_interaction_handler, + let mut solver = VarTransformation::new(BooleanExtractedSolver::new(LinearizedSolver::new( + BaseSolver::new(bus_interaction_handler), ))); solver.add_algebraic_constraints(constraint_system.algebraic_constraints); solver.add_bus_interactions(constraint_system.bus_interactions); diff --git a/constraint-solver/src/solver/linearized.rs b/constraint-solver/src/solver/linearized.rs new file mode 100644 index 000000000..05f32078c --- /dev/null +++ b/constraint-solver/src/solver/linearized.rs @@ -0,0 +1,459 @@ +use std::collections::HashMap; +use std::hash::Hash; +use std::iter; +use std::{collections::HashSet, fmt::Display}; + +use itertools::Itertools; + +use crate::constraint_system::ConstraintSystem; +use crate::indexed_constraint_system::apply_substitutions; +use crate::runtime_constant::Substitutable; +use crate::solver::var_transformation::Variable; +use crate::solver::{Error, VariableAssignment}; +use crate::{ + constraint_system::BusInteraction, + grouped_expression::{GroupedExpression, RangeConstraintProvider}, + range_constraint::RangeConstraint, + runtime_constant::{RuntimeConstant, VarTransformable}, + solver::Solver, +}; + +/// A Solver that turns algebraic constraints into affine constraints +/// by introducing new variables for the non-affine parts. +/// It also replaces bus interaction fields by new variables if they are +/// not just variables or constants. +/// +/// The original algebraic constraints are kept as well. +pub struct LinearizedSolver { + solver: S, + linearizer: Linearizer, + var_dispenser: LinearizedVarDispenser, +} + +struct LinearizedVarDispenser { + next_var_id: usize, +} + +impl LinearizedVarDispenser { + fn new() -> Self { + LinearizedVarDispenser { next_var_id: 0 } + } + + fn next_var(&mut self) -> Variable { + let id = self.next_var_id; + self.next_var_id += 1; + Variable::Linearized(id) + } + + /// Returns an iterator over all variables dispensed in the past. + fn all_dispensed_vars(&self) -> impl Iterator> { + (0..self.next_var_id).map(Variable::Linearized) + } +} + +impl LinearizedSolver +where + T: RuntimeConstant, + V: Clone + Eq, + S: Solver, +{ + pub fn new(solver: S) -> Self { + Self { + solver, + linearizer: Linearizer::default(), + var_dispenser: LinearizedVarDispenser::new(), + } + } +} + +impl RangeConstraintProvider> + for LinearizedSolver, S> +where + T: RuntimeConstant, + S: RangeConstraintProvider>, + V: Clone, +{ + fn get(&self, var: &Variable) -> RangeConstraint { + self.solver.get(var) + } +} + +impl>, V, S: Display> Display + for LinearizedSolver, S> +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.solver) + } +} + +impl Solver> for LinearizedSolver, S> +where + T: RuntimeConstant + Substitutable> + Display + Hash, + V: Ord + Clone + Eq + Hash + Display, + S: Solver>, +{ + fn solve(&mut self) -> Result>>, Error> { + let assignments = self.solver.solve()?; + // Apply the deduced assignments to the substitutions we performed. + // We assume that the user of the solver applies the assignments to + // their expressions and thus "incoming" expressions used in the functions + // `range_constraint_for_expression` and `are_expressions_known_to_be_different` + // will have the assignments applied. + self.linearizer.apply_assignments(&assignments); + Ok(assignments) + } + + fn add_algebraic_constraints( + &mut self, + constraints: impl IntoIterator>>, + ) { + let constraints = constraints + .into_iter() + .flat_map(|constr| { + // We always add the original constraint unmodified. + let mut constrs = vec![constr.clone()]; + if !constr.is_affine() { + let linearized = self.linearizer.linearize( + constr, + &mut || self.var_dispenser.next_var(), + &mut constrs, + ); + constrs.push(linearized); + } + constrs + }) + .collect::>(); + self.solver.add_algebraic_constraints(constraints); + } + + fn add_bus_interactions( + &mut self, + bus_interactions: impl IntoIterator>>>, + ) { + let mut constraints_to_add = vec![]; + let bus_interactions = bus_interactions + .into_iter() + .map(|bus_interaction| { + bus_interaction + .fields() + .map(|expr| { + self.linearizer.substitute_by_var( + expr.clone(), + &mut || self.var_dispenser.next_var(), + &mut constraints_to_add, + ) + }) + .collect::>() + }) + .collect_vec(); + // We only substituted by a variable, but the substitution was not yet linearized. + self.add_algebraic_constraints(constraints_to_add); + self.solver.add_bus_interactions(bus_interactions); + } + + fn add_range_constraint( + &mut self, + variable: &Variable, + constraint: RangeConstraint, + ) { + self.solver.add_range_constraint(variable, constraint); + } + + fn retain_variables(&mut self, variables_to_keep: &HashSet>) { + // There are constraints that only contain `Variable::Linearized` that + // connect quadratic terms with the original constraints. We could try to find + // those, but let's just keep all of them for now. + let mut variables_to_keep = variables_to_keep.clone(); + variables_to_keep.extend(self.var_dispenser.all_dispensed_vars()); + self.solver.retain_variables(&variables_to_keep); + } + + fn range_constraint_for_expression( + &self, + expr: &GroupedExpression>, + ) -> RangeConstraint { + self.internalized_versions_of_expression(expr) + .fold(RangeConstraint::default(), |acc, expr| { + acc.conjunction(&self.solver.range_constraint_for_expression(&expr)) + }) + } + + fn are_expressions_known_to_be_different( + &mut self, + a: &GroupedExpression>, + b: &GroupedExpression>, + ) -> bool { + self.solver.are_expressions_known_to_be_different(a, b) + } +} + +impl LinearizedSolver, S> +where + T: RuntimeConstant + Hash, + V: Ord + Clone + Eq + Hash, +{ + /// Returns an iterator over expressions equivalent to `expr` with the idea that + /// they might allow to answer a query better or worse. + /// It usually returns the original expression, a single variable that it was + /// substituted into during a previous linearization and a previously linearized version. + fn internalized_versions_of_expression( + &self, + expr: &GroupedExpression>, + ) -> impl Iterator>> + Clone { + let direct = expr.clone(); + // See if we have a direct substitution for the expression by a variable. + let simple_substituted = self.linearizer.try_substitute_by_existing_var(expr); + // Try to re-do the linearization + let substituted = self.linearizer.try_linearize_existing(expr.clone()); + iter::once(direct) + .chain(simple_substituted) + .chain(substituted) + } +} + +struct Linearizer { + substitutions: HashMap, V>, +} + +impl Default for Linearizer { + fn default() -> Self { + Linearizer { + substitutions: HashMap::new(), + } + } +} + +impl Linearizer { + /// Linearizes the constraint by introducing new variables for + /// non-affine parts. The new constraints are appended to + /// `constraint_collection` and must be added to the system. + /// The linearized expression is returned. + fn linearize( + &mut self, + expr: GroupedExpression, + var_dispenser: &mut impl FnMut() -> V, + constraint_collection: &mut impl Extend>, + ) -> GroupedExpression { + if expr.is_affine() { + return expr; + } + let (quadratic, linear, constant) = expr.into_components(); + quadratic + .into_iter() + .map(|(l, r)| { + let l = + self.linearize_and_substitute_by_var(l, var_dispenser, constraint_collection); + let r = + self.linearize_and_substitute_by_var(r, var_dispenser, constraint_collection); + self.substitute_by_var(l * r, var_dispenser, constraint_collection) + }) + .chain(linear.map(|(v, c)| GroupedExpression::from_unknown_variable(v) * c)) + .chain(std::iter::once(GroupedExpression::from_runtime_constant( + constant, + ))) + .sum() + } + + /// Tries to linearize the expression according to already existing substitutions. + fn try_linearize_existing( + &self, + expr: GroupedExpression, + ) -> Option> { + if expr.is_affine() { + return Some(expr); + } + let (quadratic, linear, constant) = expr.into_components(); + Some( + quadratic + .into_iter() + .map(|(l, r)| { + let l = + self.try_substitute_by_existing_var(&self.try_linearize_existing(l)?)?; + let r = + self.try_substitute_by_existing_var(&self.try_linearize_existing(r)?)?; + self.try_substitute_by_existing_var(&(l * r)) + }) + .collect::>>()? + .into_iter() + .chain(linear.map(|(v, c)| GroupedExpression::from_unknown_variable(v) * c)) + .chain(std::iter::once(GroupedExpression::from_runtime_constant( + constant, + ))) + .sum(), + ) + } + + /// Linearizes the expression and substitutes the expression by a single variable. + /// The substitution is not performed if the expression is a constant or a single + /// variable (without coefficient). + fn linearize_and_substitute_by_var( + &mut self, + expr: GroupedExpression, + var_dispenser: &mut impl FnMut() -> V, + constraint_collection: &mut impl Extend>, + ) -> GroupedExpression { + let linearized = self.linearize(expr, var_dispenser, constraint_collection); + self.substitute_by_var(linearized, var_dispenser, constraint_collection) + } + + /// Substitutes the given expression by a single variable using the variable dispenser, + /// unless the expression is already just a single variable or constant. Re-uses substitutions + /// that were made in the past. + /// Adds the equality constraint to `constraint_collection` and returns the variable + /// as an expression. + fn substitute_by_var( + &mut self, + expr: GroupedExpression, + var_dispenser: &mut impl FnMut() -> V, + constraint_collection: &mut impl Extend>, + ) -> GroupedExpression { + if let Some(var) = self.try_substitute_by_existing_var(&expr) { + var + } else { + let var = var_dispenser(); + self.substitutions.insert(expr.clone(), var.clone()); + let var = GroupedExpression::from_unknown_variable(var); + constraint_collection.extend([expr - var.clone()]); + var + } + } + + /// Tries to substitute the given expression by an existing variable. + fn try_substitute_by_existing_var( + &self, + expr: &GroupedExpression, + ) -> Option> { + if expr.try_to_known().is_some() || expr.try_to_simple_unknown().is_some() { + Some(expr.clone()) + } else { + self.substitutions + .get(expr) + .map(|var| GroupedExpression::from_unknown_variable(var.clone())) + } + } +} + +impl + Hash, V: Clone + Eq + Ord + Hash> Linearizer { + /// Applies the assignments to the stored substitutions. + fn apply_assignments(&mut self, assignments: &[VariableAssignment]) { + if assignments.is_empty() { + return; + } + let (exprs, vars): (Vec<_>, Vec<_>) = self.substitutions.drain().unzip(); + let exprs = apply_substitutions( + ConstraintSystem { + algebraic_constraints: exprs, + bus_interactions: vec![], + }, + assignments.iter().cloned(), + ) + .algebraic_constraints; + self.substitutions = exprs + .into_iter() + .zip_eq(vars) + .map(|(expr, var)| (expr, var.clone())) + .collect(); + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use powdr_number::GoldilocksField; + + use super::*; + use crate::{constraint_system::DefaultBusInteractionHandler, solver::base::BaseSolver}; + + type Qse = GroupedExpression>; + + fn var(name: &'static str) -> Qse { + GroupedExpression::from_unknown_variable(Variable::from(name)) + } + + fn constant(value: u64) -> Qse { + GroupedExpression::from_number(GoldilocksField::from(value)) + } + + #[test] + fn linearization() { + let mut var_counter = 0usize; + let mut linearizer = Linearizer::default(); + let expr = var("x") + var("y") * (var("z") + constant(1)) * (var("x") - constant(1)); + let mut constraints_to_add = vec![]; + let linearized = linearizer.linearize( + expr, + &mut || { + let var = Variable::Linearized(var_counter); + var_counter += 1; + var + }, + &mut constraints_to_add, + ); + assert_eq!(linearized.to_string(), "x + lin_3"); + assert_eq!( + constraints_to_add.into_iter().format("\n").to_string(), + "z - lin_0 + 1\n(y) * (lin_0) - lin_1\nx - lin_2 - 1\n(lin_1) * (lin_2) - lin_3" + ); + } + + #[test] + fn solver_transforms() { + let mut solver = + LinearizedSolver::new(BaseSolver::new(DefaultBusInteractionHandler::default())); + solver.add_algebraic_constraints(vec![ + (var("x") + var("y")) * (var("z") + constant(1)) * (var("x") - constant(1)), + (var("a") + var("b")) * (var("c") - constant(2)), + ]); + solver.add_bus_interactions(vec![BusInteraction { + bus_id: constant(1), + payload: vec![var("x") + var("y"), -var("a"), var("a")], + multiplicity: var("z") + constant(1), + }]); + // Below, it is important that in the bus interaction, + // `a` is not replaced and that the first payload re-uses the + // already linearized `x + y`. + expect!([r#" + ((x + y) * (z + 1)) * (x - 1) = 0 + x + y - lin_0 = 0 + z - lin_1 + 1 = 0 + (lin_0) * (lin_1) - lin_2 = 0 + x - lin_3 - 1 = 0 + (lin_2) * (lin_3) - lin_4 = 0 + lin_4 = 0 + (a + b) * (c - 2) = 0 + a + b - lin_5 = 0 + c - lin_6 - 2 = 0 + (lin_5) * (lin_6) - lin_7 = 0 + lin_7 = 0 + -(a + lin_8) = 0 + BusInteraction { bus_id: 1, multiplicity: lin_1, payload: lin_0, lin_8, a }"#]) + .assert_eq(&solver.to_string()); + let assignments = solver.solve().unwrap(); + expect!([r#" + lin_4 = 0 + lin_7 = 0"#]) + .assert_eq( + &assignments + .iter() + .map(|(var, value)| format!("{var} = {value}")) + .join("\n"), + ); + + expect!([r#" + ((x + y) * (z + 1)) * (x - 1) = 0 + x + y - lin_0 = 0 + z - lin_1 + 1 = 0 + (lin_0) * (lin_1) - lin_2 = 0 + x - lin_3 - 1 = 0 + (lin_2) * (lin_3) = 0 + 0 = 0 + (a + b) * (c - 2) = 0 + a + b - lin_5 = 0 + c - lin_6 - 2 = 0 + (lin_5) * (lin_6) = 0 + 0 = 0 + -(a + lin_8) = 0 + BusInteraction { bus_id: 1, multiplicity: lin_1, payload: lin_0, lin_8, a }"#]) + .assert_eq(&solver.to_string()); + } +} diff --git a/constraint-solver/src/solver/var_transformation.rs b/constraint-solver/src/solver/var_transformation.rs index 298e720cf..20a1281db 100644 --- a/constraint-solver/src/solver/var_transformation.rs +++ b/constraint-solver/src/solver/var_transformation.rs @@ -17,7 +17,7 @@ pub enum Variable { /// A new boolean-constrained variable that was introduced by the solver. Boolean(usize), /// A new variable introduced by the linearizer. - Linear(usize), + Linearized(usize), } impl From for Variable { @@ -48,7 +48,7 @@ impl Display for Variable { match self { Variable::Original(v) => write!(f, "{v}"), Variable::Boolean(i) => write!(f, "bool_{i}"), - Variable::Linear(i) => write!(f, "lin_{i}"), + Variable::Linearized(i) => write!(f, "lin_{i}"), } } } diff --git a/constraint-solver/tests/solver.rs b/constraint-solver/tests/solver.rs index 231848b2a..c6458581b 100644 --- a/constraint-solver/tests/solver.rs +++ b/constraint-solver/tests/solver.rs @@ -438,3 +438,28 @@ fn ternary_flags() { vec![("is_load", 1.into())], ); } + +#[test] +fn bit_decomposition_bug() { + let algebraic_constraints = vec![ + var("cmp_result_0") * (var("cmp_result_0") - constant(1)), + var("imm_0") - constant(8), + var("cmp_result_0") * var("imm_0") + - constant(4) * var("cmp_result_0") + - var("BusInteractionField(10, 2)") + + constant(4), + (var("BusInteractionField(10, 2)") - constant(4)) + * (var("BusInteractionField(10, 2)") - constant(8)), + ]; + let constraint_system = ConstraintSystem { + algebraic_constraints, + bus_interactions: vec![], + }; + // The solver used to infer more assignments due to a bug + // in the bit decomposition logic. + assert_solve_result( + constraint_system, + DefaultBusInteractionHandler::default(), + vec![("imm_0", 8.into())], + ); +} diff --git a/pipeline/build.rs b/pipeline/build.rs index 1c735a18a..add4d1264 100644 --- a/pipeline/build.rs +++ b/pipeline/build.rs @@ -21,7 +21,11 @@ fn build_reparse_test(kind: &str, dir: &str) { build_tests(kind, dir, "", "reparse") } -const SLOW_LIST: [&str; 1] = ["keccakf16_test"]; +const SLOW_LIST: [&str; 3] = [ + "keccakf16_test", + "keccakf16_memory_test", + "keccakf32_memory_test", +]; #[allow(clippy::print_stdout)] fn build_tests(kind: &str, dir: &str, sub_dir: &str, name: &str) { diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index 964b5900b..296d73610 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -487,12 +487,7 @@ mod reparse { /// but these tests panic if the field is too small. This is *probably* /// fine, because all of these tests have a similar variant that does /// run on Goldilocks. - const BLACKLIST: [&str; 4] = [ - "std/poseidon_bn254_test.asm", - "std/split_bn254_test.asm", - "keccakf16_memory_test", - "keccakf32_memory_test", - ]; + const BLACKLIST: [&str; 2] = ["std/poseidon_bn254_test.asm", "std/split_bn254_test.asm"]; fn run_reparse_test(file: &str) { run_reparse_test_with_blacklist(file, &BLACKLIST);