From a22065ee2d864849b7b14c9ac88f67f2015369ac Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Mon, 30 Jun 2025 17:47:12 +0200 Subject: [PATCH] Refactor `QuadraticSymbolicExpression` (#2923) This PR refactors `QuadraticSymbolicExpression`: I removed the reference to [`SymbolicExpression`](https://github.com/powdr-labs/powdr/blob/f60e9d9f699810b565b077911c3836bcb1f9a356/constraint-solver/src/quadratic_symbolic_expression.rs#L70). It can represent values that are known at runtime, represented as an expression. We don't need this in the context of APC: Variables are either unknown or known at compile time, and therefore can be represented as a `FieldElement` instead. The idea is to introduce [this trait](https://github.com/powdr-labs/powdr/blob/4989be08f343ea4b8e8d8a8e30a97b247b9b1a71/constraint-solver/src/runtime_constant.rs#L11), which is implemented by `SymbolicExpression` and any `T: FieldElement`. This way, we can continue to use the solver with `QuadraticSymbolicExpression, V>` (equivalent to the current `QuadraticSymbolicExpression`) in the old JIT pipeline and use the simpler `QuadraticSymbolicExpression` in the APC pipeline. --------- Co-authored-by: chriseth --- .../algebraic_expression_conversion.rs | 31 +- autoprecompiles/Cargo.toml | 1 + .../src/bitwise_lookup_optimizer.rs | 19 +- autoprecompiles/src/constraint_optimizer.rs | 5 +- autoprecompiles/src/expression_conversion.rs | 47 +- autoprecompiles/src/memory_optimizer.rs | 4 +- autoprecompiles/src/optimizer.rs | 3 +- constraint-solver/src/constraint_system.rs | 3 +- constraint-solver/src/effect.rs | 64 +-- .../src/indexed_constraint_system.rs | 7 +- .../src/journaling_constraint_system.rs | 4 +- constraint-solver/src/lib.rs | 1 + .../src/quadratic_symbolic_expression.rs | 412 ++++++++++-------- constraint-solver/src/runtime_constant.rs | 113 +++++ constraint-solver/src/solver.rs | 7 +- .../bus_interaction_variable_wrapper.rs | 6 +- .../src/solver/quadratic_equivalences.rs | 1 + constraint-solver/src/symbolic_expression.rs | 217 ++++++--- .../src/symbolic_to_quadratic.rs | 47 -- constraint-solver/src/test_utils.rs | 2 +- constraint-solver/tests/solver.rs | 4 +- .../src/witgen/jit/block_machine_processor.rs | 7 +- executor/src/witgen/jit/compiler.rs | 6 +- executor/src/witgen/jit/debug_formatter.rs | 4 +- executor/src/witgen/jit/effect.rs | 28 +- executor/src/witgen/jit/identity_queue.rs | 1 + executor/src/witgen/jit/interpreter.rs | 2 +- executor/src/witgen/jit/processor.rs | 5 +- executor/src/witgen/jit/witgen_inference.rs | 8 +- expression/src/conversion.rs | 38 -- expression/src/lib.rs | 37 +- number/src/expression_convertible.rs | 49 +++ number/src/lib.rs | 3 + pilopt/src/qse_opt.rs | 6 +- 34 files changed, 734 insertions(+), 458 deletions(-) create mode 100644 constraint-solver/src/runtime_constant.rs delete mode 100644 expression/src/conversion.rs create mode 100644 number/src/expression_convertible.rs diff --git a/ast/src/analyzed/algebraic_expression_conversion.rs b/ast/src/analyzed/algebraic_expression_conversion.rs index 98d875426..d0a789841 100644 --- a/ast/src/analyzed/algebraic_expression_conversion.rs +++ b/ast/src/analyzed/algebraic_expression_conversion.rs @@ -1,3 +1,4 @@ +use num_traits::One; use powdr_number::{FieldElement, LargeInt}; use super::{ @@ -5,7 +6,8 @@ use super::{ AlgebraicUnaryOperation, AlgebraicUnaryOperator, Challenge, }; -pub trait TerminalConverter { +pub trait TerminalConverter { + fn convert_number(&mut self, number: &T) -> Target; fn convert_reference(&mut self, reference: &AlgebraicReference) -> Target; fn convert_public_reference(&mut self, reference: &str) -> Target; fn convert_challenge(&mut self, challenge: &Challenge) -> Target; @@ -15,11 +17,11 @@ pub trait TerminalConverter { /// The `terminal_converter` is used to convert the terminal nodes of the expression. pub fn convert( expr: &AlgebraicExpression, - terminal_converter: &mut impl TerminalConverter, + terminal_converter: &mut impl TerminalConverter, ) -> Target where - Target: From - + Clone + Target: Clone + + One + std::ops::Add + std::ops::Sub + std::ops::Mul @@ -29,7 +31,7 @@ where AlgebraicExpression::Reference(r) => terminal_converter.convert_reference(r), AlgebraicExpression::PublicReference(r) => terminal_converter.convert_public_reference(r), AlgebraicExpression::Challenge(c) => terminal_converter.convert_challenge(c), - AlgebraicExpression::Number(n) => (*n).into(), + AlgebraicExpression::Number(n) => terminal_converter.convert_number(n), AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { if *op == AlgebraicBinaryOperator::Pow { let AlgebraicExpression::Number(exponent) = right.as_ref() else { @@ -62,13 +64,12 @@ where } } -fn apply_pow(v: &Target, exponent: u32) -> Target +fn apply_pow(v: &Target, exponent: u32) -> Target where - T: From, - Target: From + Clone + std::ops::Mul, + Target: Clone + One + std::ops::Mul, { if exponent == 0 { - Target::from(T::from(1)) + Target::one() } else if exponent & 1 == 1 { let r: Target = apply_pow(v, exponent >> 1); (r.clone() * r) * v.clone() @@ -85,11 +86,11 @@ mod test { #[test] fn test_apply_pow() { let v = 9u64; - assert_eq!(apply_pow::(&v, 0), 1); - assert_eq!(apply_pow::(&v, 1), 9); - assert_eq!(apply_pow::(&v, 2), 9 * 9); - assert_eq!(apply_pow::(&v, 3), 9 * 9 * 9); - assert_eq!(apply_pow::(&v, 4), 9 * 9 * 9 * 9); - assert_eq!(apply_pow::(&v, 5), 9 * 9 * 9 * 9 * 9); + assert_eq!(apply_pow::(&v, 0), 1); + assert_eq!(apply_pow::(&v, 1), 9); + assert_eq!(apply_pow::(&v, 2), 9 * 9); + assert_eq!(apply_pow::(&v, 3), 9 * 9 * 9); + assert_eq!(apply_pow::(&v, 4), 9 * 9 * 9 * 9); + assert_eq!(apply_pow::(&v, 5), 9 * 9 * 9 * 9 * 9); } } diff --git a/autoprecompiles/Cargo.toml b/autoprecompiles/Cargo.toml index 4825618b2..d70a5c822 100644 --- a/autoprecompiles/Cargo.toml +++ b/autoprecompiles/Cargo.toml @@ -13,6 +13,7 @@ powdr-constraint-solver.workspace = true itertools = "0.13" log = "0.4.18" +num-traits = "0.2.15" serde = "1.0.218" [package.metadata.cargo-udeps.ignore] diff --git a/autoprecompiles/src/bitwise_lookup_optimizer.rs b/autoprecompiles/src/bitwise_lookup_optimizer.rs index 9c31883a9..e5b906fa4 100644 --- a/autoprecompiles/src/bitwise_lookup_optimizer.rs +++ b/autoprecompiles/src/bitwise_lookup_optimizer.rs @@ -3,6 +3,7 @@ use std::hash::Hash; use std::{fmt::Debug, fmt::Display}; use itertools::Itertools; +use num_traits::{One, Zero}; use powdr_constraint_solver::constraint_system::{BusInteraction, ConstraintSystem}; use powdr_constraint_solver::quadratic_symbolic_expression::QuadraticSymbolicExpression; use powdr_number::FieldElement; @@ -40,14 +41,14 @@ pub fn optimize_bitwise_lookup( +fn is_simple_multiplicity_bitwise_bus_interaction( bus_int: &BusInteraction>, bitwise_lookup_bus_id: u64, ) -> bool { - bus_int.bus_id == T::from(bitwise_lookup_bus_id).into() - && bus_int.multiplicity == T::from(1).into() + bus_int.bus_id == QuadraticSymbolicExpression::from_number(T::from(bitwise_lookup_bus_id)) + && bus_int.multiplicity.is_one() } /// Returns all expressions that are byte-constrained in the machine. diff --git a/autoprecompiles/src/constraint_optimizer.rs b/autoprecompiles/src/constraint_optimizer.rs index 565347327..6c81a66da 100644 --- a/autoprecompiles/src/constraint_optimizer.rs +++ b/autoprecompiles/src/constraint_optimizer.rs @@ -1,6 +1,7 @@ use std::{collections::HashSet, fmt::Display, hash::Hash}; use inliner::DegreeBound; +use num_traits::Zero; use powdr_constraint_solver::{ constraint_system::BusInteractionHandler, inliner, journaling_constraint_system::JournalingConstraintSystem, @@ -135,10 +136,10 @@ fn remove_disconnected_columns constraint_system } -fn remove_trivial_constraints( +fn remove_trivial_constraints( mut constraint_system: JournalingConstraintSystem, ) -> JournalingConstraintSystem { - let zero = QuadraticSymbolicExpression::from(P::zero()); + let zero = QuadraticSymbolicExpression::zero(); constraint_system.retain_algebraic_constraints(|constraint| constraint != &zero); constraint_system .retain_bus_interactions(|bus_interaction| bus_interaction.multiplicity != zero); diff --git a/autoprecompiles/src/expression_conversion.rs b/autoprecompiles/src/expression_conversion.rs index 788ffe2e9..1a1d9405a 100644 --- a/autoprecompiles/src/expression_conversion.rs +++ b/autoprecompiles/src/expression_conversion.rs @@ -1,12 +1,9 @@ use powdr_constraint_solver::{ - quadratic_symbolic_expression::QuadraticSymbolicExpression, - symbolic_expression::{BinaryOperator, SymbolicExpression, UnaryOperator}, + quadratic_symbolic_expression::QuadraticSymbolicExpression, runtime_constant::RuntimeConstant, + symbolic_expression::SymbolicExpression, }; -use powdr_expression::{ - AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicUnaryOperation, - AlgebraicUnaryOperator, -}; -use powdr_number::FieldElement; +use powdr_expression::{AlgebraicUnaryOperation, AlgebraicUnaryOperator}; +use powdr_number::{ExpressionConvertible, FieldElement}; use crate::expression::{AlgebraicExpression, AlgebraicReference}; @@ -15,9 +12,10 @@ use crate::expression::{AlgebraicExpression, AlgebraicReference}; pub fn algebraic_to_quadratic_symbolic_expression( expr: &AlgebraicExpression, ) -> QuadraticSymbolicExpression { - powdr_expression::conversion::convert(expr, &mut |reference| { - QuadraticSymbolicExpression::from_unknown_variable(reference.clone()) - }) + expr.to_expression( + &|n| QuadraticSymbolicExpression::from_number(*n), + &|reference| QuadraticSymbolicExpression::from_unknown_variable(reference.clone()), + ) } /// Turns a quadratic symbolic expression back into an algebraic expression. @@ -83,37 +81,16 @@ pub fn quadratic_symbolic_expression_to_algebraic( fn symbolic_expression_to_algebraic( e: &SymbolicExpression, ) -> AlgebraicExpression { - match e { - SymbolicExpression::Concrete(v) => { + e.to_expression( + &|v| { if v.is_in_lower_half() { AlgebraicExpression::from(*v) } else { -AlgebraicExpression::from(-*v) } - } - SymbolicExpression::Symbol(r, _) => AlgebraicExpression::Reference(r.clone()), - SymbolicExpression::BinaryOperation(left, op, right, _) => { - let left = Box::new(symbolic_expression_to_algebraic(left)); - let right = Box::new(symbolic_expression_to_algebraic(right)); - let op = symbolic_op_to_algebraic(*op); - AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) - } - SymbolicExpression::UnaryOperation(op, inner, _) => match op { - UnaryOperator::Neg => AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { - expr: Box::new(symbolic_expression_to_algebraic(inner)), - op: AlgebraicUnaryOperator::Minus, - }), }, - } -} - -fn symbolic_op_to_algebraic(op: BinaryOperator) -> AlgebraicBinaryOperator { - match op { - BinaryOperator::Add => AlgebraicBinaryOperator::Add, - BinaryOperator::Sub => AlgebraicBinaryOperator::Sub, - BinaryOperator::Mul => AlgebraicBinaryOperator::Mul, - BinaryOperator::Div => unreachable!(), - } + &|r| AlgebraicExpression::Reference(r.clone()), + ) } /// If `e` is negated, returns the expression without negation and `true`, diff --git a/autoprecompiles/src/memory_optimizer.rs b/autoprecompiles/src/memory_optimizer.rs index f18ec2002..959bac210 100644 --- a/autoprecompiles/src/memory_optimizer.rs +++ b/autoprecompiles/src/memory_optimizer.rs @@ -39,7 +39,7 @@ pub fn optimize_memory( // Check that the number of register memory bus interactions for each concrete address in the precompile is even. // Assumption: all register memory bus interactions feature a concrete address. -pub fn check_register_operation_consistency( +pub fn check_register_operation_consistency( system: &ConstraintSystem, memory_bus_id: u64, ) -> bool { @@ -89,7 +89,7 @@ struct MemoryBusInteraction { timestamp: QuadraticSymbolicExpression, } -impl MemoryBusInteraction { +impl MemoryBusInteraction { /// Tries to convert a `BusInteraction` to a `MemoryBusInteraction`. /// /// Returns `Ok(None)` if we know that the bus interaction is not a memory bus interaction. diff --git a/autoprecompiles/src/optimizer.rs b/autoprecompiles/src/optimizer.rs index 90e7bb169..5cb450b76 100644 --- a/autoprecompiles/src/optimizer.rs +++ b/autoprecompiles/src/optimizer.rs @@ -4,7 +4,6 @@ use powdr_constraint_solver::{ constraint_system::{BusInteraction, BusInteractionHandler, ConstraintSystem}, journaling_constraint_system::JournalingConstraintSystem, quadratic_symbolic_expression::{NoRangeConstraints, QuadraticSymbolicExpression}, - symbolic_expression::SymbolicExpression, }; use powdr_number::FieldElement; @@ -249,7 +248,7 @@ fn symbolic_bus_interaction_to_bus_interaction( bus_interaction: &SymbolicBusInteraction

, ) -> BusInteraction> { BusInteraction { - bus_id: SymbolicExpression::Concrete(P::from(bus_interaction.id)).into(), + bus_id: QuadraticSymbolicExpression::from_number(P::from(bus_interaction.id)), payload: bus_interaction .args .iter() diff --git a/constraint-solver/src/constraint_system.rs b/constraint-solver/src/constraint_system.rs index a0ade639f..6690987ae 100644 --- a/constraint-solver/src/constraint_system.rs +++ b/constraint-solver/src/constraint_system.rs @@ -2,6 +2,7 @@ use crate::{ effect::Effect, quadratic_symbolic_expression::{QuadraticSymbolicExpression, RangeConstraintProvider}, range_constraint::RangeConstraint, + runtime_constant::RuntimeConstant, }; use itertools::Itertools; use powdr_number::FieldElement; @@ -17,7 +18,7 @@ pub struct ConstraintSystem { pub bus_interactions: Vec>>, } -impl Display for ConstraintSystem { +impl Display for ConstraintSystem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, diff --git a/constraint-solver/src/effect.rs b/constraint-solver/src/effect.rs index fa384cbae..741c3ea0d 100644 --- a/constraint-solver/src/effect.rs +++ b/constraint-solver/src/effect.rs @@ -3,28 +3,35 @@ use std::fmt::{self, Display, Formatter}; use itertools::Itertools; use powdr_number::FieldElement; -use crate::{range_constraint::RangeConstraint, symbolic_expression::SymbolicExpression}; +use crate::{ + range_constraint::RangeConstraint, runtime_constant::RuntimeConstant, + symbolic_expression::SymbolicExpression, +}; /// The effect of solving a symbolic equation. #[derive(Clone, PartialEq, Eq)] -pub enum Effect { +pub enum EffectImpl { /// Variable can be assigned a value. - Assignment(V, SymbolicExpression), + Assignment(V, T), /// Perform a bit decomposition of a known value, and assign multiple variables. BitDecomposition(BitDecomposition), /// We learnt a new range constraint on variable. - RangeConstraint(V, RangeConstraint), + RangeConstraint(V, RangeConstraint), /// A run-time assertion. If this fails, we have conflicting constraints. - Assertion(Assertion), + Assertion(Assertion), /// A variable is assigned one of two alternative expressions, depending on a condition. ConditionalAssignment { variable: V, - condition: Condition, - in_range_value: SymbolicExpression, - out_of_range_value: SymbolicExpression, + condition: Condition, + in_range_value: T, + out_of_range_value: T, }, } +// TODO: This type is equivalent to a pre-refactoring version of `EffectImpl`. +// It should be removed in a follow-up PR & we should rename `EffectImpl` to `Effect`. +pub type Effect = EffectImpl, V>; + /// A bit decomposition of a value. /// Executing this effect solves the following equation: /// value = sum_{i=0}^{components.len() - 1} (-1)**components[i].negative * 2**components[i].exponent * components[i].variable @@ -32,14 +39,14 @@ pub enum Effect { /// This effect can only be created if the equation has a unique solution. /// It might be that it leads to a contradiction, which should result in an assertion failure. #[derive(Clone, PartialEq, Eq)] -pub struct BitDecomposition { +pub struct BitDecomposition { /// The value that is decomposed. - pub value: SymbolicExpression, + pub value: T, /// The components of the decomposition. - pub components: Vec>, + pub components: Vec>, } -impl Display for BitDecomposition { +impl Display for BitDecomposition { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let BitDecomposition { value, components } = self; write!(f, "{} := {value};", components.iter().format(" + ")) @@ -81,33 +88,30 @@ impl Display for BitDecompositionComponent { /// A run-time assertion. If this fails, we have conflicting constraints. #[derive(Clone, PartialEq, Eq)] -pub struct Assertion { - pub lhs: SymbolicExpression, - pub rhs: SymbolicExpression, +pub struct Assertion { + pub lhs: T, + pub rhs: T, /// If this is true, we assert that both sides are equal. /// Otherwise, we assert that they are different. pub expected_equal: bool, } -impl Assertion { - pub fn assert_is_zero(condition: SymbolicExpression) -> Effect { - Self::assert_eq(condition, SymbolicExpression::from(T::from(0))) +impl Assertion { + pub fn assert_is_zero(condition: T) -> EffectImpl { + Self::assert_eq(condition, T::from_u64(0)) } - pub fn assert_is_nonzero(condition: SymbolicExpression) -> Effect { - Self::assert_neq(condition, SymbolicExpression::from(T::from(0))) + pub fn assert_is_nonzero(condition: T) -> EffectImpl { + Self::assert_neq(condition, T::from_u64(0)) } - pub fn assert_eq(lhs: SymbolicExpression, rhs: SymbolicExpression) -> Effect { - Effect::Assertion(Assertion { + pub fn assert_eq(lhs: T, rhs: T) -> EffectImpl { + EffectImpl::Assertion(Assertion { lhs, rhs, expected_equal: true, }) } - pub fn assert_neq( - lhs: SymbolicExpression, - rhs: SymbolicExpression, - ) -> Effect { - Effect::Assertion(Assertion { + pub fn assert_neq(lhs: T, rhs: T) -> EffectImpl { + EffectImpl::Assertion(Assertion { lhs, rhs, expected_equal: false, @@ -116,7 +120,7 @@ impl Assertion { } #[derive(Clone, PartialEq, Eq)] -pub struct Condition { - pub value: SymbolicExpression, - pub condition: RangeConstraint, +pub struct Condition { + pub value: T, + pub condition: RangeConstraint, } diff --git a/constraint-solver/src/indexed_constraint_system.rs b/constraint-solver/src/indexed_constraint_system.rs index d356315c4..1b3802d94 100644 --- a/constraint-solver/src/indexed_constraint_system.rs +++ b/constraint-solver/src/indexed_constraint_system.rs @@ -241,7 +241,7 @@ impl IndexedConstraintSystem ) { let bus_interaction = &mut self.constraint_system.bus_interactions[interaction_index]; let field = bus_interaction.fields_mut().nth(field_index).unwrap(); - *field = value.into(); + *field = QuadraticSymbolicExpression::from_number(value); } /// Substitute an unknown variable by a QuadraticSymbolicExpression in the whole system. @@ -417,7 +417,7 @@ fn substitute_by_unknown_in_item( } } -impl Display for IndexedConstraintSystem { +impl Display for IndexedConstraintSystem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.constraint_system) } @@ -475,8 +475,7 @@ mod tests { s.substitute_by_unknown( &"z", - &(Qse::from_unknown_variable("x") - + Qse::from(SymbolicExpression::from(GoldilocksField::from(7)))), + &(Qse::from_unknown_variable("x") + Qse::from_number(GoldilocksField::from(7))), ); assert_eq!( diff --git a/constraint-solver/src/journaling_constraint_system.rs b/constraint-solver/src/journaling_constraint_system.rs index b17939a2d..09e7e31ba 100644 --- a/constraint-solver/src/journaling_constraint_system.rs +++ b/constraint-solver/src/journaling_constraint_system.rs @@ -103,7 +103,9 @@ impl JournalingConstraintSystem { } } -impl Display for JournalingConstraintSystem { +impl Display + for JournalingConstraintSystem +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.system) } diff --git a/constraint-solver/src/lib.rs b/constraint-solver/src/lib.rs index f5ff42630..bce4a1926 100644 --- a/constraint-solver/src/lib.rs +++ b/constraint-solver/src/lib.rs @@ -8,6 +8,7 @@ pub mod inliner; pub mod journaling_constraint_system; pub mod quadratic_symbolic_expression; pub mod range_constraint; +pub mod runtime_constant; pub mod solver; pub mod symbolic_expression; pub mod symbolic_to_quadratic; diff --git a/constraint-solver/src/quadratic_symbolic_expression.rs b/constraint-solver/src/quadratic_symbolic_expression.rs index 50637831c..809e09484 100644 --- a/constraint-solver/src/quadratic_symbolic_expression.rs +++ b/constraint-solver/src/quadratic_symbolic_expression.rs @@ -5,32 +5,33 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}, }; -use itertools::Itertools; -use num_traits::Zero; -use powdr_number::{log2_exact, FieldElement, LargeInt}; - use crate::{ - effect::Condition, symbolic_to_quadratic::symbolic_expression_to_quadratic_symbolic_expression, + effect::Condition, + runtime_constant::{ReferencedSymbols, RuntimeConstant, Substitutable}, }; +use itertools::Itertools; +use num_traits::One; +use num_traits::Zero; +use powdr_number::{log2_exact, ExpressionConvertible, FieldElement, LargeInt}; -use super::effect::{Assertion, BitDecomposition, BitDecompositionComponent, Effect}; +use super::effect::{Assertion, BitDecomposition, BitDecompositionComponent, EffectImpl}; use super::range_constraint::RangeConstraint; use super::symbolic_expression::SymbolicExpression; #[derive(Default)] -pub struct ProcessResult { - pub effects: Vec>, +pub struct ProcessResult { + pub effects: Vec>, pub complete: bool, } -impl ProcessResult { +impl ProcessResult { pub fn empty() -> Self { Self { effects: vec![], complete: false, } } - pub fn complete(effects: Vec>) -> Self { + pub fn complete(effects: Vec>) -> Self { Self { effects, complete: true, @@ -58,49 +59,90 @@ pub enum Error { /// an unknown variable gets known and provides functions to solve /// (some kinds of) equations. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct QuadraticSymbolicExpression { +pub struct QuadraticSymbolicExpressionImpl { /// Quadratic terms of the form `a * X * Y`, where `a` is a (symbolically) /// known value and `X` and `Y` are quadratic symbolic expressions that /// have at least one unknown. quadratic: Vec<(Self, Self)>, /// Linear terms of the form `a * X`, where `a` is a (symbolically) known /// value and `X` is an unknown variable. - linear: BTreeMap>, + linear: BTreeMap, /// Constant term, a (symbolically) known value. - constant: SymbolicExpression, + constant: T, } -impl From> for QuadraticSymbolicExpression { - fn from(k: SymbolicExpression) -> Self { +// TODO: This type is equivalent to a pre-refactoring version of `QuadraticSymbolicExpressionImpl`. +// It should be removed in a follow-up PR & we should rename `QuadraticSymbolicExpressionImpl` to `QuadraticSymbolicExpression`. +pub type QuadraticSymbolicExpression = + QuadraticSymbolicExpressionImpl, V>; + +impl, V> QuadraticSymbolicExpressionImpl { + pub fn from_number(k: F) -> Self { Self { quadratic: Default::default(), linear: Default::default(), - constant: k, + constant: T::from(k), } } } -impl From for QuadraticSymbolicExpression { - fn from(k: T) -> Self { - SymbolicExpression::from(k).into() +impl Zero + for QuadraticSymbolicExpressionImpl +{ + fn zero() -> Self { + Self { + quadratic: Default::default(), + linear: Default::default(), + constant: T::zero(), + } + } + + fn is_zero(&self) -> bool { + self.try_to_known().is_some_and(|k| k.is_known_zero()) } } -impl QuadraticSymbolicExpression { - pub fn from_known_symbol(symbol: V, rc: RangeConstraint) -> Self { - SymbolicExpression::from_symbol(symbol, rc).into() +impl One for QuadraticSymbolicExpressionImpl { + fn one() -> Self { + Self { + quadratic: Default::default(), + linear: Default::default(), + constant: T::one(), + } + } + + fn is_one(&self) -> bool { + self.try_to_known().is_some_and(|k| k.is_known_one()) + } +} + +impl + QuadraticSymbolicExpressionImpl, V> +{ + pub fn from_known_symbol(symbol: V, rc: RangeConstraint) -> Self { + Self::from_runtime_constant(SymbolicExpression::from_symbol(symbol, rc)) + } +} + +impl QuadraticSymbolicExpressionImpl { + pub fn from_runtime_constant(constant: T) -> Self { + Self { + quadratic: Default::default(), + linear: Default::default(), + constant, + } } pub fn from_unknown_variable(var: V) -> Self { Self { quadratic: Default::default(), - linear: [(var.clone(), T::from(1).into())].into_iter().collect(), - constant: T::from(0).into(), + linear: [(var.clone(), T::one())].into_iter().collect(), + constant: T::zero(), } } /// If this expression does not contain unknown variables, returns the symbolic expression. - pub fn try_to_known(&self) -> Option<&SymbolicExpression> { + pub fn try_to_known(&self) -> Option<&T> { if self.quadratic.is_empty() && self.linear.is_empty() { Some(&self.constant) } else { @@ -114,7 +156,7 @@ impl QuadraticSymbolicExpression { } /// If the expression is a known number, returns it. - pub fn try_to_number(&self) -> Option { + pub fn try_to_number(&self) -> Option { self.try_to_known()?.try_to_number() } @@ -150,20 +192,13 @@ impl QuadraticSymbolicExpression { } /// Returns the quadratic, linear and constant components of this expression. - #[allow(clippy::type_complexity)] - pub fn components( - &self, - ) -> ( - &[(Self, Self)], - impl Iterator)>, - &SymbolicExpression, - ) { + pub fn components(&self) -> (&[(Self, Self)], impl Iterator, &T) { (&self.quadratic, self.linear.iter(), &self.constant) } /// Returns the coefficient of the variable `variable` if this is an affine expression. /// Panics if the expression is quadratic. - pub fn coefficient_of_variable(&self, var: &V) -> Option<&SymbolicExpression> { + pub fn coefficient_of_variable(&self, var: &V) -> Option<&T> { assert!(!self.is_quadratic()); self.linear.get(var) } @@ -171,8 +206,8 @@ impl QuadraticSymbolicExpression { /// Returns the range constraint of the full expression. pub fn range_constraint( &self, - range_constraints: &impl RangeConstraintProvider, - ) -> RangeConstraint { + range_constraints: &impl RangeConstraintProvider, + ) -> RangeConstraint { self.quadratic .iter() .map(|(l, r)| { @@ -190,10 +225,12 @@ impl QuadraticSymbolicExpression { } } -impl QuadraticSymbolicExpression { +impl, V: Ord + Clone + Eq + Hash> + QuadraticSymbolicExpressionImpl +{ /// Substitute a variable by a symbolically known expression. The variable can be known or unknown. /// If it was already known, it will be substituted in the known expressions. - pub fn substitute_by_known(&mut self, variable: &V, substitution: &SymbolicExpression) { + pub fn substitute_by_known(&mut self, variable: &V, substitution: &T) { self.constant.substitute(variable, substitution); if self.linear.contains_key(variable) { @@ -202,7 +239,7 @@ impl QuadraticSymbolicExpression QuadraticSymbolicExpression { - to_add += (l * r).into(); + to_add += QuadraticSymbolicExpressionImpl::from_runtime_constant( + l.clone() * r.clone(), + ); false } (Some(l), None) => { @@ -245,13 +284,13 @@ impl QuadraticSymbolicExpression, + substitution: &QuadraticSymbolicExpressionImpl, ) { if !self.referenced_unknown_variables().any(|v| v == variable) { return; } - let mut to_add = QuadraticSymbolicExpression::from(T::zero()); + let mut to_add = QuadraticSymbolicExpressionImpl::zero(); for (var, coeff) in std::mem::take(&mut self.linear) { if var == *variable { to_add += substitution.clone() * coeff; @@ -267,7 +306,7 @@ impl QuadraticSymbolicExpression { - to_add += (lval * rval).into(); + to_add += Self::from_runtime_constant(lval.clone() * rval.clone()); None } (Some(lval), None) => { @@ -285,7 +324,11 @@ impl QuadraticSymbolicExpression, V: Ord + Clone + Eq + Hash> + QuadraticSymbolicExpressionImpl +{ /// Returns the set of referenced variables, both know and unknown. Might contain repetitions. pub fn referenced_variables(&self) -> Box + '_> { let quadr = self @@ -300,7 +343,9 @@ impl QuadraticSymbolicExpression QuadraticSymbolicExpressionImpl { /// Returns the referenced unknown variables. Might contain repetitions. pub fn referenced_unknown_variables(&self) -> Box + '_> { let quadratic = self.quadratic.iter().flat_map(|(a, b)| { @@ -311,12 +356,14 @@ impl QuadraticSymbolicExpression QuadraticSymbolicExpression { +impl + QuadraticSymbolicExpressionImpl, V1> +{ pub fn transform_var_type( &self, var_transform: &mut impl FnMut(&V1) -> V2, - ) -> QuadraticSymbolicExpression { - QuadraticSymbolicExpression { + ) -> QuadraticSymbolicExpressionImpl, V2> { + QuadraticSymbolicExpressionImpl { quadratic: self .quadratic .iter() @@ -352,7 +399,11 @@ impl RangeConstraintProvider for NoRangeConstraints { } } -impl QuadraticSymbolicExpression { +impl< + T: RuntimeConstant + Display + ExpressionConvertible<::FieldType, V>, + V: Ord + Clone + Hash + Eq + Display, + > QuadraticSymbolicExpressionImpl +{ /// Solves the equation `self = 0` and returns how to compute the solution. /// The solution can contain assignments to multiple variables. /// If no way to solve the equation (and no way to derive new range @@ -361,7 +412,7 @@ impl QuadraticSymbolicExp /// If the equation is known to be unsolvable, returns an error. pub fn solve( &self, - range_constraints: &impl RangeConstraintProvider, + range_constraints: &impl RangeConstraintProvider, ) -> Result, Error> { Ok(if self.is_quadratic() { self.solve_quadratic(range_constraints)? @@ -383,7 +434,7 @@ impl QuadraticSymbolicExp /// has a coefficient which is known to be not zero. /// /// Returns the resulting solved quadratic symbolic expression. - pub fn try_solve_for(&self, variable: &V) -> Option> { + pub fn try_solve_for(&self, variable: &V) -> Option> { if self .quadratic .iter() @@ -399,7 +450,7 @@ impl QuadraticSymbolicExp } let mut result = self.clone(); let coefficient = result.linear.remove(variable)?; - Some(result * (SymbolicExpression::from(-T::from(1)).field_div(&coefficient))) + Some(result * (-T::one().field_div(&coefficient))) } /// Algebraically transforms the constraint such that `self = 0` is equivalent @@ -409,8 +460,8 @@ impl QuadraticSymbolicExp /// Panics if `expr` is quadratic. pub fn try_solve_for_expr( &self, - expr: &QuadraticSymbolicExpression, - ) -> Option> { + expr: &QuadraticSymbolicExpressionImpl, + ) -> Option> { assert!( expr.is_affine(), "Tried to solve for quadratic expression {expr}" @@ -431,7 +482,7 @@ impl QuadraticSymbolicExp None } }) - .unwrap_or(T::from(1).into()); + .unwrap_or(T::one()); let result = expr - &(self.clone() * normalization_factor); // Check that the operations removed all variables in `expr` from `self`. @@ -452,7 +503,7 @@ impl QuadraticSymbolicExp fn solve_affine( &self, - range_constraints: &impl RangeConstraintProvider, + range_constraints: &impl RangeConstraintProvider, ) -> Result, Error> { Ok(if self.linear.len() == 1 { let (var, coeff) = self.linear.iter().next().unwrap(); @@ -463,7 +514,7 @@ impl QuadraticSymbolicExp ); if coeff.is_known_nonzero() { // In this case, we can always compute a solution. - let value = self.constant.field_div(&-coeff); + let value = self.constant.field_div(&-coeff.clone()); ProcessResult::complete(vec![assignment_if_satisfies_range_constraints( var.clone(), value, @@ -472,7 +523,7 @@ impl QuadraticSymbolicExp } else if self.constant.is_known_nonzero() { // If the offset is not zero, then the coefficient must be non-zero, // otherwise the constraint is violated. - let value = self.constant.field_div(&-coeff); + let value = self.constant.field_div(&-coeff.clone()); ProcessResult::complete(vec![ Assertion::assert_is_nonzero(coeff.clone()), assignment_if_satisfies_range_constraints( @@ -504,7 +555,7 @@ impl QuadraticSymbolicExp /// Tries to solve a bit-decomposition equation. fn solve_bit_decomposition( &self, - range_constraints: &impl RangeConstraintProvider, + range_constraints: &impl RangeConstraintProvider, ) -> Result, Error> { assert!(!self.is_quadratic()); // All the coefficients need to be known numbers and the @@ -537,8 +588,8 @@ impl QuadraticSymbolicExp let mut concrete_assignments = vec![]; // Check if they are mutually exclusive and compute assignments. - let mut covered_bits: ::Integer = 0.into(); - let mut components = vec![]; + let mut covered_bits: ::Integer = 0.into(); + let mut components: Vec> = vec![]; for (variable, constraint, is_negative, coeff_abs, exponent) in constrained_coefficients .into_iter() .sorted_by_key(|(_, _, _, _, exponent)| *exponent) @@ -555,15 +606,17 @@ impl QuadraticSymbolicExp // if it is not known, we return a BitDecomposition effect. if let Some(offset) = &mut offset { let mut component = if is_negative { -*offset } else { *offset }.to_integer(); - if component > (T::modulus() - 1.into()) >> 1 { + if component > (T::FieldType::modulus() - 1.into()) >> 1 { // Convert a signed finite field element into two's complement. // a regular subtraction would underflow, so we do this. // We add the difference between negative numbers in the field // and negative numbers in two's complement. - component += T::Integer::MAX - T::modulus() + 1.into(); + component += ::Integer::MAX + - T::FieldType::modulus() + + 1.into(); }; component &= bit_mask; - if component >= T::modulus() { + if component >= T::FieldType::modulus() { // If the component does not fit the field, the bit mask is not // tight good enough. return Ok(ProcessResult::empty()); @@ -571,12 +624,15 @@ impl QuadraticSymbolicExp concrete_assignments.push( // We're not using assignment_if_satisfies_range_constraints here, because we // might still exit early. The error case is handled below. - Effect::Assignment(variable.clone(), T::from(component >> exponent).into()), + EffectImpl::Assignment( + variable.clone(), + T::FieldType::from(component >> exponent).into(), + ), ); if is_negative { - *offset += T::from(component); + *offset += T::FieldType::from(component); } else { - *offset -= T::from(component); + *offset -= T::FieldType::from(component); } } else { components.push(BitDecompositionComponent { @@ -588,7 +644,7 @@ impl QuadraticSymbolicExp } } - if covered_bits >= T::modulus() { + if covered_bits >= T::FieldType::modulus() { return Ok(ProcessResult::empty()); } @@ -599,7 +655,7 @@ impl QuadraticSymbolicExp assert_eq!(concrete_assignments.len(), self.linear.len()); Ok(ProcessResult::complete(concrete_assignments)) } else { - Ok(ProcessResult::complete(vec![Effect::BitDecomposition( + Ok(ProcessResult::complete(vec![EffectImpl::BitDecomposition( BitDecomposition { value: self.constant.clone(), components, @@ -610,8 +666,8 @@ impl QuadraticSymbolicExp fn transfer_constraints( &self, - range_constraints: &impl RangeConstraintProvider, - ) -> Vec> { + range_constraints: &impl RangeConstraintProvider, + ) -> Vec> { // Solve for each of the variables in the linear component and // compute the range constraints. assert!(!self.is_quadratic()); @@ -622,13 +678,13 @@ impl QuadraticSymbolicExp Some((var, rc)) }) .filter(|(_, constraint)| !constraint.is_unconstrained()) - .map(|(var, constraint)| Effect::RangeConstraint(var.clone(), constraint)) + .map(|(var, constraint)| EffectImpl::RangeConstraint(var.clone(), constraint)) .collect() } fn solve_quadratic( &self, - range_constraints: &impl RangeConstraintProvider, + range_constraints: &impl RangeConstraintProvider, ) -> Result, Error> { let Some((left, right)) = self.try_as_single_product() else { return Ok(ProcessResult::empty()); @@ -663,15 +719,18 @@ impl QuadraticSymbolicExp /// Tries to combine two process results from alternative branches into a /// conditional assignment. -fn combine_to_conditional_assignment( +fn combine_to_conditional_assignment< + T: RuntimeConstant + ExpressionConvertible<::FieldType, V>, + V: Ord + Clone + Hash + Eq + Display, +>( left: &ProcessResult, right: &ProcessResult, - range_constraints: &impl RangeConstraintProvider, + range_constraints: &impl RangeConstraintProvider, ) -> Option> { - let [Effect::Assignment(first_var, first_assignment)] = left.effects.as_slice() else { + let [EffectImpl::Assignment(first_var, first_assignment)] = left.effects.as_slice() else { return None; }; - let [Effect::Assignment(second_var, second_assignment)] = right.effects.as_slice() else { + let [EffectImpl::Assignment(second_var, second_assignment)] = right.effects.as_slice() else { return None; }; @@ -684,9 +743,13 @@ fn combine_to_conditional_assignment = diff.try_to_expression( + &|n| QuadraticSymbolicExpression::from_number(*n), + &|v| QuadraticSymbolicExpressionImpl::from_unknown_variable(v.clone()), + &|e| e.try_to_number(), )?; + let diff = diff.try_to_known()?.try_to_number()?; // `diff = A - B` is a compile-time known number, i.e. `A = B + diff`. // Now if `rc + diff` is disjoint from `rc`, it means @@ -702,7 +765,7 @@ fn combine_to_conditional_assignment( +fn combine_range_constraints( left: &ProcessResult, right: &ProcessResult, ) -> ProcessResult { @@ -760,37 +823,37 @@ fn combine_range_constraints( +fn assignment_if_satisfies_range_constraints( var: V, - value: SymbolicExpression, - range_constraints: &impl RangeConstraintProvider, -) -> Result, Error> { + value: T, + range_constraints: &impl RangeConstraintProvider, +) -> Result, Error> { let rc = range_constraints.get(&var); if rc.is_disjoint(&value.range_constraint()) { return Err(Error::ConflictingRangeConstraints); } - Ok(Effect::Assignment(var, value)) + Ok(EffectImpl::Assignment(var, value)) } /// Turns an effect into a range constraint on a variable. -fn effect_to_range_constraint( - effect: &Effect, -) -> Option<(V, RangeConstraint)> { +fn effect_to_range_constraint( + effect: &EffectImpl, +) -> Option<(V, RangeConstraint)> { match effect { - Effect::RangeConstraint(var, rc) => Some((var.clone(), rc.clone())), - Effect::Assignment(var, value) => Some((var.clone(), value.range_constraint())), + EffectImpl::RangeConstraint(var, rc) => Some((var.clone(), rc.clone())), + EffectImpl::Assignment(var, value) => Some((var.clone(), value.range_constraint())), _ => None, } } -impl Add for QuadraticSymbolicExpression { - type Output = QuadraticSymbolicExpression; +impl Add for QuadraticSymbolicExpressionImpl { + type Output = QuadraticSymbolicExpressionImpl; fn add(mut self, rhs: Self) -> Self { self += rhs; @@ -798,16 +861,18 @@ impl Add for QuadraticSymbolicExpre } } -impl Add for &QuadraticSymbolicExpression { - type Output = QuadraticSymbolicExpression; +impl Add + for &QuadraticSymbolicExpressionImpl +{ + type Output = QuadraticSymbolicExpressionImpl; fn add(self, rhs: Self) -> Self::Output { self.clone() + rhs.clone() } } -impl AddAssign> - for QuadraticSymbolicExpression +impl + AddAssign> for QuadraticSymbolicExpressionImpl { fn add_assign(&mut self, rhs: Self) { self.quadratic.extend(rhs.quadratic); @@ -822,23 +887,25 @@ impl AddAssign Sub for &QuadraticSymbolicExpression { - type Output = QuadraticSymbolicExpression; +impl Sub + for &QuadraticSymbolicExpressionImpl +{ + type Output = QuadraticSymbolicExpressionImpl; fn sub(self, rhs: Self) -> Self::Output { self + &-rhs } } -impl Sub for QuadraticSymbolicExpression { - type Output = QuadraticSymbolicExpression; +impl Sub for QuadraticSymbolicExpressionImpl { + type Output = QuadraticSymbolicExpressionImpl; fn sub(self, rhs: Self) -> Self::Output { &self - &rhs } } -impl QuadraticSymbolicExpression { +impl QuadraticSymbolicExpressionImpl { fn negate(&mut self) { for (first, _) in &mut self.quadratic { first.negate() @@ -850,8 +917,8 @@ impl QuadraticSymbolicExpression { } } -impl Neg for QuadraticSymbolicExpression { - type Output = QuadraticSymbolicExpression; +impl Neg for QuadraticSymbolicExpressionImpl { + type Output = QuadraticSymbolicExpressionImpl; fn neg(mut self) -> Self { self.negate(); @@ -859,8 +926,8 @@ impl Neg for QuadraticSymbolicExpression } } -impl Neg for &QuadraticSymbolicExpression { - type Output = QuadraticSymbolicExpression; +impl Neg for &QuadraticSymbolicExpressionImpl { + type Output = QuadraticSymbolicExpressionImpl; fn neg(self) -> Self::Output { -((*self).clone()) @@ -868,33 +935,33 @@ impl Neg for &QuadraticSymbolicExpression } /// Multiply by known symbolic expression. -impl Mul<&SymbolicExpression> - for QuadraticSymbolicExpression +impl Mul<&T> + for QuadraticSymbolicExpressionImpl { - type Output = QuadraticSymbolicExpression; + type Output = QuadraticSymbolicExpressionImpl; - fn mul(mut self, rhs: &SymbolicExpression) -> Self { + fn mul(mut self, rhs: &T) -> Self { self *= rhs; self } } -impl Mul> - for QuadraticSymbolicExpression +impl Mul + for QuadraticSymbolicExpressionImpl { - type Output = QuadraticSymbolicExpression; + type Output = QuadraticSymbolicExpressionImpl; - fn mul(self, rhs: SymbolicExpression) -> Self { + fn mul(self, rhs: T) -> Self { self * &rhs } } -impl MulAssign<&SymbolicExpression> - for QuadraticSymbolicExpression +impl MulAssign<&T> + for QuadraticSymbolicExpressionImpl { - fn mul_assign(&mut self, rhs: &SymbolicExpression) { + fn mul_assign(&mut self, rhs: &T) { if rhs.is_known_zero() { - *self = T::zero().into(); + *self = Self::zero(); } else { for (first, _) in &mut self.quadratic { *first *= rhs; @@ -907,10 +974,10 @@ impl MulAssign<&SymbolicExpression< } } -impl Mul for QuadraticSymbolicExpression { - type Output = QuadraticSymbolicExpression; +impl Mul for QuadraticSymbolicExpressionImpl { + type Output = QuadraticSymbolicExpressionImpl; - fn mul(self, rhs: QuadraticSymbolicExpression) -> Self { + fn mul(self, rhs: QuadraticSymbolicExpressionImpl) -> Self { if let Some(k) = rhs.try_to_known() { self * k } else if let Some(k) = self.try_to_known() { @@ -919,13 +986,15 @@ impl Mul for QuadraticSymbolicExpre Self { quadratic: vec![(self, rhs)], linear: Default::default(), - constant: T::from(0).into(), + constant: T::zero(), } } } } -impl Display for QuadraticSymbolicExpression { +impl Display + for QuadraticSymbolicExpressionImpl +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let (sign, s) = self.to_signed_string(); if sign { @@ -936,7 +1005,7 @@ impl Display for QuadraticSymbolicExp } } -impl QuadraticSymbolicExpression { +impl QuadraticSymbolicExpressionImpl { fn to_signed_string(&self) -> (bool, String) { self.quadratic .iter() @@ -949,8 +1018,8 @@ impl QuadraticSymbolicExpression (false, format!("{var}")), - Some(k) if k == (-1).into() => (true, format!("{var}")), + Some(k) if k == T::FieldType::one() => (false, format!("{var}")), + Some(k) if k == -T::FieldType::one() => (true, format!("{var}")), _ => { let (sign, coeff) = Self::symbolic_expression_to_signed_string(coeff); (sign, format!("{coeff} * {var}")) @@ -958,7 +1027,7 @@ impl QuadraticSymbolicExpression None, + Some(k) if k == T::FieldType::zero() => None, _ => Some(Self::symbolic_expression_to_signed_string(&self.constant)), }) .reduce(|(n1, p1), (n2, p2)| { @@ -974,7 +1043,7 @@ impl QuadraticSymbolicExpression) -> (bool, String) { + fn symbolic_expression_to_signed_string(value: &T) -> (bool, String) { match value.try_to_number() { Some(k) => { if k.is_in_lower_half() { @@ -999,7 +1068,10 @@ mod tests { use pretty_assertions::assert_eq; - type Qse = QuadraticSymbolicExpression; + type Qse = QuadraticSymbolicExpressionImpl< + SymbolicExpression, + &'static str, + >; #[test] fn test_mul() { @@ -1039,7 +1111,7 @@ mod tests { let x = Qse::from_unknown_variable("X"); let y = Qse::from_unknown_variable("Y"); let a = Qse::from_known_symbol("A", RangeConstraint::default()); - let zero = Qse::from(GoldilocksField::from(0)); + let zero = Qse::zero(); let t: Qse = x * y + a; assert_eq!(t.to_string(), "(X) * (Y) + A"); assert_eq!((t.clone() * zero).to_string(), "0"); @@ -1125,7 +1197,7 @@ mod tests { #[test] fn unsolvable() { - let r = Qse::from(GoldilocksField::from(10)).solve(&NoRangeConstraints); + let r = Qse::from_number(GoldilocksField::from(10)).solve(&NoRangeConstraints); assert!(r.is_err()); } @@ -1133,14 +1205,12 @@ mod tests { fn unsolvable_with_vars() { let x = &Qse::from_known_symbol("X", Default::default()); let y = &Qse::from_known_symbol("Y", Default::default()); - let mut constr = x + y - GoldilocksField::from(10).into(); + let mut constr = x + y - constant(10); // We cannot solve it, but we can also not learn anything new from it. let result = constr.solve(&NoRangeConstraints).unwrap(); assert!(result.complete && result.effects.is_empty()); // But if we know the values, we can be sure there is a conflict. - assert!(Qse::from(GoldilocksField::from(10)) - .solve(&NoRangeConstraints) - .is_err()); + assert!(constant(10).solve(&NoRangeConstraints).is_err()); // The same with range constraints that disallow zero. constr.substitute_by_known( @@ -1154,14 +1224,12 @@ mod tests { RangeConstraint::from_range(100.into(), 102.into()), ), ); - assert!(Qse::from(GoldilocksField::from(10)) - .solve(&NoRangeConstraints) - .is_err()); + assert!(constant(10).solve(&NoRangeConstraints).is_err()); } #[test] fn solvable_without_vars() { - let constr = Qse::from(GoldilocksField::from(0)); + let constr = constant(0); let result = constr.solve(&NoRangeConstraints).unwrap(); assert!(result.complete && result.effects.is_empty()); } @@ -1171,14 +1239,14 @@ mod tests { let y = Qse::from_known_symbol("y", Default::default()); let x = Qse::from_unknown_variable("X"); // 2 * X + 7 * y - 10 = 0 - let two = Qse::from(GoldilocksField::from(2)); - let seven = Qse::from(GoldilocksField::from(7)); - let ten = Qse::from(GoldilocksField::from(10)); + let two = constant(2); + let seven = constant(7); + let ten = constant(10); let constr = two * x + seven * y - ten; let result = constr.solve(&NoRangeConstraints).unwrap(); assert!(result.complete); assert_eq!(result.effects.len(), 1); - let Effect::Assignment(var, expr) = &result.effects[0] else { + let EffectImpl::Assignment(var, expr) = &result.effects[0] else { panic!("Expected assignment"); }; assert_eq!(var.to_string(), "X"); @@ -1191,8 +1259,8 @@ mod tests { let z = Qse::from_known_symbol("z", Default::default()); let x = Qse::from_unknown_variable("X"); // z * X + 7 * y - 10 = 0 - let seven = Qse::from(GoldilocksField::from(7)); - let ten = Qse::from(GoldilocksField::from(10)); + let seven = constant(7); + let ten = constant(10); let mut constr = z * x + seven * y - ten.clone(); // If we do not range-constrain z, we cannot solve since we don't know if it might be zero. let result = constr.solve(&NoRangeConstraints).unwrap(); @@ -1211,7 +1279,7 @@ mod tests { let result = constr.solve(&range_constraints).unwrap(); assert!(result.complete); let effects = result.effects; - let Effect::Assignment(var, expr) = &effects[0] else { + let EffectImpl::Assignment(var, expr) = &effects[0] else { panic!("Expected assignment"); }; assert_eq!(var.to_string(), "X"); @@ -1227,10 +1295,9 @@ mod tests { let c = Qse::from_unknown_variable("c"); let z = Qse::from_known_symbol("Z", Default::default()); // a * 0x100 - b * 0x10000 + c * 0x1000000 + 10 + Z = 0 - let ten = Qse::from(GoldilocksField::from(10)); - let constr: Qse = a * Qse::from(GoldilocksField::from(0x100)) - - b * Qse::from(GoldilocksField::from(0x10000)) - + c * Qse::from(GoldilocksField::from(0x1000000)) + let ten = constant(10); + let constr: Qse = a * constant(0x100) - b * constant(0x10000) + + c * constant(0x1000000) + ten.clone() + z.clone(); // Without range constraints on a, this is not solvable. @@ -1245,7 +1312,7 @@ mod tests { let [effect] = &result.effects[..] else { panic!(); }; - let Effect::BitDecomposition(BitDecomposition { value, components }) = effect else { + let EffectImpl::BitDecomposition(BitDecomposition { value, components }) = effect else { panic!(); }; assert_eq!(format!("{value}"), "(10 + Z)"); @@ -1282,19 +1349,17 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; let range_constraints = HashMap::from([("a", rc.clone()), ("b", rc.clone()), ("c", rc.clone())]); // a * 0x100 + b * 0x10000 + c * 0x1000000 + 10 - Z = 0 - let ten = Qse::from(GoldilocksField::from(10)); - let constr = a * Qse::from(GoldilocksField::from(0x100)) - + b * Qse::from(GoldilocksField::from(0x10000)) - + c * Qse::from(GoldilocksField::from(0x1000000)) - + ten.clone() - - z.clone(); + let ten = constant(10); + let constr = + a * constant(0x100) + b * constant(0x10000) + c * constant(0x1000000) + ten.clone() + - z.clone(); let result = constr.solve(&range_constraints).unwrap(); assert!(!result.complete); let effects = result .effects .into_iter() .map(|effect| match effect { - Effect::RangeConstraint(v, rc) => format!("{v}: {rc};\n"), + EffectImpl::RangeConstraint(v, rc) => format!("{v}: {rc};\n"), _ => panic!(), }) .format("") @@ -1314,8 +1379,8 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; let a = Qse::from_unknown_variable("a"); let b = Qse::from_known_symbol("b", rc.clone()); let range_constraints = HashMap::from([("a", rc.clone()), ("b", rc.clone())]); - let ten = Qse::from(GoldilocksField::from(10)); - let two_pow8 = Qse::from(GoldilocksField::from(0x100)); + let ten = constant(10); + let two_pow8 = constant(0x100); let constr = (a.clone() - b.clone() + two_pow8 - ten.clone()) * (a - b - ten); let result = constr.solve(&range_constraints).unwrap(); assert!(result.complete); @@ -1323,7 +1388,7 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; .effects .into_iter() .map(|effect| match effect { - Effect::ConditionalAssignment { + EffectImpl::ConditionalAssignment { variable, condition: Condition { value, condition }, in_range_value, @@ -1347,7 +1412,7 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; constr.substitute_by_known(&"b", &GoldilocksField::from(2).into()); let result = constr.solve(&range_constraints).unwrap(); assert!(result.complete); - let [Effect::Assignment(var, expr)] = result.effects.as_slice() else { + let [EffectImpl::Assignment(var, expr)] = result.effects.as_slice() else { panic!("Expected 1 assignment"); }; assert_eq!(var, &"a"); @@ -1355,12 +1420,15 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; } fn unpack_range_constraint( - process_result: &ProcessResult, + process_result: &ProcessResult< + SymbolicExpression, + &'static str, + >, ) -> (&'static str, RangeConstraint) { let [effect] = &process_result.effects[..] else { panic!(); }; - let Effect::RangeConstraint(var, rc) = effect else { + let EffectImpl::RangeConstraint(var, rc) = effect else { panic!(); }; (var, rc.clone()) @@ -1369,9 +1437,9 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; #[test] fn detect_bit_constraint() { let a = Qse::from_unknown_variable("a"); - let one = Qse::from(GoldilocksField::from(1)); - let three = Qse::from(GoldilocksField::from(3)); - let five = Qse::from(GoldilocksField::from(5)); + let one = constant(1); + let three = constant(3); + let five = constant(5); // All these constraints should be equivalent to a bit constraint. let constraints = [ @@ -1392,8 +1460,8 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; #[test] fn detect_complete_range_constraint() { let a = Qse::from_unknown_variable("a"); - let three = Qse::from(GoldilocksField::from(3)); - let four = Qse::from(GoldilocksField::from(4)); + let three = constant(3); + let four = constant(4); // `a` can be 3 or 4, which is can be completely represented by // RangeConstraint::from_range(3, 4), so the identity should be @@ -1413,8 +1481,8 @@ c = (((10 + Z) & 0xff000000) >> 24) [negative]; #[test] fn detect_incomplete_range_constraint() { let a = Qse::from_unknown_variable("a"); - let three = Qse::from(GoldilocksField::from(3)); - let five = Qse::from(GoldilocksField::from(5)); + let three = constant(3); + let five = constant(5); // `a` can be 3 or 5, so there is a range constraint // RangeConstraint::from_range(3, 5) on `a`. diff --git a/constraint-solver/src/runtime_constant.rs b/constraint-solver/src/runtime_constant.rs new file mode 100644 index 000000000..2d215f18d --- /dev/null +++ b/constraint-solver/src/runtime_constant.rs @@ -0,0 +1,113 @@ +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; + +use num_traits::{One, Zero}; +use powdr_number::FieldElement; + +use crate::range_constraint::RangeConstraint; + +/// Represents a run-time constant in the constraint solver, built over +/// a base field type. +/// The base field type itself (i.e. any T: FieldElement) represents a run-time constant +/// (which is also a compile-time constant), but the trait lets us represent run-time +/// constants symbolically as well. +pub trait RuntimeConstant: + Sized + + Neg + + Clone + + From + + Add + + AddAssign + + Sub + + Mul + + MulAssign + + PartialEq + + Eq + + Zero + + One +{ + type FieldType: FieldElement; + + /// Tries to convert the constant to a single number. This always works for compile-time constants. + fn try_to_number(&self) -> Option; + + /// Returns the range constraint for this constant. For compile-time constants, + /// this will be a single value range constraint. + fn range_constraint(&self) -> RangeConstraint; + + /// Divides this constant by another constant, returning a new constant. + fn field_div(&self, other: &Self) -> Self; + + /// Converts a u64 to a run-time constant. + fn from_u64(k: u64) -> Self { + Self::from(Self::FieldType::from(k)) + } + + /// Returns whether this constant is known to be zero at compile time. + fn is_known_zero(&self) -> bool { + self.try_to_number().is_some_and(|n| n.is_zero()) + } + + /// Returns whether this constant is known to be one at compile time. + fn is_known_one(&self) -> bool { + self.try_to_number().is_some_and(|n| n.is_one()) + } + + /// Returns whether this constant is known to be -1 at compile time. + fn is_known_minus_one(&self) -> bool { + self.try_to_number() + .is_some_and(|n| n == -Self::FieldType::from(1)) + } + + /// Returns whether this constant is known to be non-zero at compile time. + /// Note that this could return true even if the constant is not known fully + /// at compile time, but it is guaranteed that the constant is not zero. + fn is_known_nonzero(&self) -> bool { + // Only checking range constraint is enough since if this is a known + // fixed value, we will get a range constraint with just a single value. + !self.range_constraint().allows_value(0.into()) + } +} + +pub trait ReferencedSymbols { + /// Returns an iterator over all referenced symbols in this constant. + fn referenced_symbols<'a>(&'a self) -> impl Iterator + 'a + where + V: 'a; +} + +pub trait Substitutable { + /// Substitutes a variable with another constant. + fn substitute(&mut self, variable: &V, substitution: &Self); +} + +impl RuntimeConstant for T { + type FieldType = T; + + fn try_to_number(&self) -> Option { + Some(*self) + } + + fn range_constraint(&self) -> RangeConstraint { + RangeConstraint::from_value(*self) + } + + fn field_div(&self, other: &Self) -> Self { + *self / *other + } +} + +impl ReferencedSymbols for T { + fn referenced_symbols<'a>(&'a self) -> impl Iterator + 'a + where + V: 'a, + { + // No symbols in numbers. + std::iter::empty() + } +} + +impl Substitutable for T { + fn substitute(&mut self, _variable: &V, _substitution: &Self) { + // No-op for numbers. + } +} diff --git a/constraint-solver/src/solver.rs b/constraint-solver/src/solver.rs index 80466dd64..5d6cb5a7a 100644 --- a/constraint-solver/src/solver.rs +++ b/constraint-solver/src/solver.rs @@ -203,7 +203,10 @@ impl>) -> bool { match effect { - Effect::Assignment(v, expr) => self.apply_assignment(&v, &expr.into()), + Effect::Assignment(v, expr) => self.apply_assignment( + &v, + &QuadraticSymbolicExpression::from_runtime_constant(expr), + ), Effect::RangeConstraint(v, range_constraint) => { self.apply_range_constraint_update(&v, range_constraint) } @@ -223,7 +226,7 @@ impl BusInteractionVariab // Apply any assignments to the bus interaction field definitions. if let Variable::BusInteractionField(..) = variable { if let Some(value) = expr.try_to_number() { - self.bus_interaction_vars - .insert(variable.clone(), value.into()); + self.bus_interaction_vars.insert( + variable.clone(), + QuadraticSymbolicExpression::from_number(value), + ); } else { // Non-concrete assignments are only generated by the `quadratic_equivalences` module, // and bus interaction fields should only appear in bus interactions and constraints diff --git a/constraint-solver/src/solver/quadratic_equivalences.rs b/constraint-solver/src/solver/quadratic_equivalences.rs index 7a8bf9092..871f9b7e9 100644 --- a/constraint-solver/src/solver/quadratic_equivalences.rs +++ b/constraint-solver/src/solver/quadratic_equivalences.rs @@ -8,6 +8,7 @@ use powdr_number::FieldElement; use crate::{ quadratic_symbolic_expression::{QuadraticSymbolicExpression, RangeConstraintProvider}, range_constraint::RangeConstraint, + runtime_constant::RuntimeConstant, symbolic_expression::SymbolicExpression, }; diff --git a/constraint-solver/src/symbolic_expression.rs b/constraint-solver/src/symbolic_expression.rs index 02918d048..a31675bfa 100644 --- a/constraint-solver/src/symbolic_expression.rs +++ b/constraint-solver/src/symbolic_expression.rs @@ -1,5 +1,6 @@ use auto_enums::auto_enum; use itertools::Itertools; +use num_traits::{One, Zero}; use std::hash::Hash; use std::ops::Sub; use std::ops::{AddAssign, MulAssign}; @@ -10,7 +11,9 @@ use std::{ sync::Arc, }; -use powdr_number::FieldElement; +use powdr_number::{ExpressionConvertible, FieldElement}; + +use crate::runtime_constant::{ReferencedSymbols, RuntimeConstant, Substitutable}; use super::range_constraint::RangeConstraint; @@ -72,42 +75,6 @@ impl SymbolicExpression { SymbolicExpression::Symbol(symbol, rc) } } - - pub fn is_known_zero(&self) -> bool { - self.try_to_number().is_some_and(|n| n.is_zero()) - } - - pub fn is_known_one(&self) -> bool { - self.try_to_number().is_some_and(|n| n.is_one()) - } - - pub fn is_known_minus_one(&self) -> bool { - self.try_to_number().is_some_and(|n| n == -T::from(1)) - } - - pub fn is_known_nonzero(&self) -> bool { - // Only checking range constraint is enough since if this is a known - // fixed value, we will get a range constraint with just a single value. - !self.range_constraint().allows_value(0.into()) - } - - pub fn range_constraint(&self) -> RangeConstraint { - match self { - SymbolicExpression::Concrete(v) => RangeConstraint::from_value(*v), - SymbolicExpression::Symbol(.., rc) - | SymbolicExpression::BinaryOperation(.., rc) - | SymbolicExpression::UnaryOperation(.., rc) => rc.clone(), - } - } - - pub fn try_to_number(&self) -> Option { - match self { - SymbolicExpression::Concrete(n) => Some(*n), - SymbolicExpression::Symbol(..) - | SymbolicExpression::BinaryOperation(..) - | SymbolicExpression::UnaryOperation(..) => None, - } - } } impl SymbolicExpression { @@ -150,6 +117,51 @@ impl SymbolicExpression { } } +impl ExpressionConvertible for SymbolicExpression { + /// Turns a SymbolicExpression into an expression over its variables, essentially + /// making all variables unknown variables. + /// + /// Fails in case a division operation is used. + fn try_to_expression< + E: Add + Sub + Mul + Neg, + >( + &self, + number_converter: &impl Fn(&T) -> E, + var_converter: &impl Fn(&V) -> E, + try_to_number: &impl Fn(&E) -> Option, + ) -> Option { + Some(match self { + SymbolicExpression::Concrete(value) => number_converter(value), + SymbolicExpression::Symbol(var, _) => var_converter(var), + SymbolicExpression::BinaryOperation(left, op, right, _) => { + let left = + left.try_to_expression(number_converter, var_converter, try_to_number)?; + let right = + right.try_to_expression(number_converter, var_converter, try_to_number)?; + match op { + BinaryOperator::Add => left + right, + BinaryOperator::Sub => left - right, + BinaryOperator::Mul => left * right, + BinaryOperator::Div => { + if let Some(right) = try_to_number(&right) { + left * number_converter(&(T::from(1) / right)) + } else { + return None; + } + } + } + } + SymbolicExpression::UnaryOperation(op, inner, _) => { + let inner = + inner.try_to_expression(number_converter, var_converter, try_to_number)?; + match op { + UnaryOperator::Neg => -inner, + } + } + }) + } +} + impl SymbolicExpression { pub fn transform_var_type( &self, @@ -235,7 +247,7 @@ impl From for SymbolicExpression { } } -impl Add for &SymbolicExpression { +impl Add for &SymbolicExpression { type Output = SymbolicExpression; fn add(self, rhs: Self) -> Self::Output { @@ -265,20 +277,20 @@ impl Add for &SymbolicExpression { } } -impl Add for SymbolicExpression { +impl Add for SymbolicExpression { type Output = SymbolicExpression; fn add(self, rhs: Self) -> Self::Output { &self + &rhs } } -impl AddAssign for SymbolicExpression { +impl AddAssign for SymbolicExpression { fn add_assign(&mut self, rhs: Self) { *self = self.clone() + rhs; } } -impl Sub for &SymbolicExpression { +impl Sub for &SymbolicExpression { type Output = SymbolicExpression; fn sub(self, rhs: Self) -> Self::Output { @@ -304,14 +316,14 @@ impl Sub for &SymbolicExpression { } } -impl Sub for SymbolicExpression { +impl Sub for SymbolicExpression { type Output = SymbolicExpression; fn sub(self, rhs: Self) -> Self::Output { &self - &rhs } } -impl Neg for &SymbolicExpression { +impl Neg for &SymbolicExpression { type Output = SymbolicExpression; fn neg(self) -> Self::Output { @@ -360,14 +372,14 @@ impl Neg for &SymbolicExpression { } } -impl Neg for SymbolicExpression { +impl Neg for SymbolicExpression { type Output = SymbolicExpression; fn neg(self) -> Self::Output { -&self } } -impl Mul for &SymbolicExpression { +impl Mul for &SymbolicExpression { type Output = SymbolicExpression; fn mul(self, rhs: Self) -> Self::Output { @@ -395,43 +407,20 @@ impl Mul for &SymbolicExpression { } } -impl Mul for SymbolicExpression { +impl Mul for SymbolicExpression { type Output = SymbolicExpression; fn mul(self, rhs: Self) -> Self { &self * &rhs } } -impl MulAssign for SymbolicExpression { +impl MulAssign for SymbolicExpression { fn mul_assign(&mut self, rhs: Self) { *self = self.clone() * rhs; } } -impl SymbolicExpression { - /// Field element division. - /// If you use this, you must ensure that the divisor is not zero. - pub fn field_div(&self, rhs: &Self) -> Self { - if let (SymbolicExpression::Concrete(a), SymbolicExpression::Concrete(b)) = (self, rhs) { - assert!(b != &T::from(0)); - SymbolicExpression::Concrete(*a / *b) - } else if self.is_known_zero() { - SymbolicExpression::Concrete(T::from(0)) - } else if rhs.is_known_one() { - self.clone() - } else if rhs.is_known_minus_one() { - -self - } else { - // TODO other simplifications like `-x / -y => x / y`, `-x / concrete => x / -concrete`, etc. - SymbolicExpression::BinaryOperation( - Arc::new(self.clone()), - BinaryOperator::Div, - Arc::new(rhs.clone()), - Default::default(), - ) - } - } - +impl SymbolicExpression { /// Returns the multiplicative inverse in the field. pub fn field_inverse(&self) -> Self { if let SymbolicExpression::Concrete(x) = self { @@ -454,3 +443,89 @@ impl SymbolicExpression { } } } + +impl Zero for SymbolicExpression { + fn zero() -> Self { + SymbolicExpression::Concrete(T::from(0)) + } + + fn is_zero(&self) -> bool { + self.is_known_zero() + } +} + +impl One for SymbolicExpression { + fn one() -> Self { + SymbolicExpression::Concrete(T::from(1)) + } + + fn is_one(&self) -> bool { + self.is_known_one() + } +} + +impl RuntimeConstant for SymbolicExpression { + type FieldType = T; + + fn try_to_number(&self) -> Option { + match self { + SymbolicExpression::Concrete(n) => Some(*n), + SymbolicExpression::Symbol(..) + | SymbolicExpression::BinaryOperation(..) + | SymbolicExpression::UnaryOperation(..) => None, + } + } + + fn range_constraint(&self) -> RangeConstraint { + match self { + SymbolicExpression::Concrete(v) => RangeConstraint::from_value(*v), + SymbolicExpression::Symbol(.., rc) + | SymbolicExpression::BinaryOperation(.., rc) + | SymbolicExpression::UnaryOperation(.., rc) => rc.clone(), + } + } + + /// Field element division. + /// If you use this, you must ensure that the divisor is not zero. + fn field_div(&self, rhs: &Self) -> Self { + if let (SymbolicExpression::Concrete(a), SymbolicExpression::Concrete(b)) = (self, rhs) { + assert!(b != &T::from(0)); + SymbolicExpression::Concrete(*a / *b) + } else if self.is_known_zero() { + SymbolicExpression::Concrete(T::from(0)) + } else if rhs.is_known_one() { + self.clone() + } else if rhs.is_known_minus_one() { + -self + } else { + // TODO other simplifications like `-x / -y => x / y`, `-x / concrete => x / -concrete`, etc. + SymbolicExpression::BinaryOperation( + Arc::new(self.clone()), + BinaryOperator::Div, + Arc::new(rhs.clone()), + Default::default(), + ) + } + } + + fn from_u64(k: u64) -> Self { + SymbolicExpression::Concrete(T::from(k)) + } +} + +impl ReferencedSymbols + for SymbolicExpression +{ + fn referenced_symbols<'a>(&'a self) -> impl Iterator + 'a + where + V: 'a, + { + SymbolicExpression::referenced_symbols(self) + } +} + +impl Substitutable for SymbolicExpression { + fn substitute(&mut self, variable: &V, substitution: &Self) { + SymbolicExpression::substitute(self, variable, substitution); + } +} diff --git a/constraint-solver/src/symbolic_to_quadratic.rs b/constraint-solver/src/symbolic_to_quadratic.rs index 72b6e762f..8b1378917 100644 --- a/constraint-solver/src/symbolic_to_quadratic.rs +++ b/constraint-solver/src/symbolic_to_quadratic.rs @@ -1,48 +1 @@ -use std::hash::Hash; -use powdr_number::FieldElement; - -use crate::{ - quadratic_symbolic_expression::QuadraticSymbolicExpression, - symbolic_expression::{BinaryOperator, SymbolicExpression, UnaryOperator}, -}; - -/// Turns a SymbolicExpression to a QuadraticSymbolicExpression essentially -/// making all variables unknown variables. -/// -/// Fails in case a division operation is used. -pub fn symbolic_expression_to_quadratic_symbolic_expression< - T: FieldElement, - V: Clone + Ord + Hash, ->( - e: &SymbolicExpression, -) -> Option> { - Some(match e { - SymbolicExpression::Concrete(value) => (*value).into(), - SymbolicExpression::Symbol(var, _) => { - QuadraticSymbolicExpression::from_unknown_variable(var.clone()) - } - SymbolicExpression::BinaryOperation(left, op, right, _) => { - let left = symbolic_expression_to_quadratic_symbolic_expression(left)?; - let right = symbolic_expression_to_quadratic_symbolic_expression(right)?; - match op { - BinaryOperator::Add => left + right, - BinaryOperator::Sub => left - right, - BinaryOperator::Mul => left * right, - BinaryOperator::Div => { - if let Some(right) = right.try_to_known() { - left * right.field_inverse() - } else { - return None; - } - } - } - } - SymbolicExpression::UnaryOperation(op, inner, _) => { - let inner = symbolic_expression_to_quadratic_symbolic_expression(inner)?; - match op { - UnaryOperator::Neg => -inner, - } - } - }) -} diff --git a/constraint-solver/src/test_utils.rs b/constraint-solver/src/test_utils.rs index 804543432..a5fe92477 100644 --- a/constraint-solver/src/test_utils.rs +++ b/constraint-solver/src/test_utils.rs @@ -10,5 +10,5 @@ pub fn var(name: Var) -> Qse { } pub fn constant(value: u64) -> Qse { - GoldilocksField::from(value).into() + Qse::from_number(GoldilocksField::from(value)) } diff --git a/constraint-solver/tests/solver.rs b/constraint-solver/tests/solver.rs index 03af22f4d..db94d7557 100644 --- a/constraint-solver/tests/solver.rs +++ b/constraint-solver/tests/solver.rs @@ -371,8 +371,8 @@ fn binary_flags() { fn ternary_flags() { // Implementing this logic in the OpenVM load/store chip: // https://github.com/openvm-org/openvm/blob/v1.2.0/extensions/rv32im/circuit/src/loadstore/core.rs#L110-L139 - let two_inv = Qse::from(GoldilocksField::one() / GoldilocksField::from(2)); - let neg_one = Qse::from(-GoldilocksField::one()); + let two_inv = Qse::from_number(GoldilocksField::one() / GoldilocksField::from(2)); + let neg_one = Qse::from_number(-GoldilocksField::one()); let sum = var("flag0") + var("flag1") + var("flag2") + var("flag3"); // The flags must be 0, 1, or 2, and their sum must be 1 or 2. // Given these constraints, there are 14 possible assignments. The following diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index 6ebc8fa47..b46cf33c1 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -2,6 +2,7 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use bit_vec::BitVec; use itertools::Itertools; +use num_traits::{One, Zero}; use powdr_ast::analyzed::{ContainsNextRef, PolyID, PolynomialType}; use powdr_constraint_solver::quadratic_symbolic_expression::QuadraticSymbolicExpression; use powdr_number::FieldElement; @@ -85,7 +86,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { let selector = &bus_receive.selected_payload.selector; queue_items.extend(algebraic_expression_to_queue_items( selector, - T::one(), + QuadraticSymbolicExpression::one(), self.latch_row as i32, &witgen, )); @@ -94,7 +95,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { // Set the known argument to the concrete value. queue_items.extend(algebraic_expression_to_queue_items( &bus_receive.selected_payload.expressions[index], - value, + QuadraticSymbolicExpression::from_number(value), self.latch_row as i32, &witgen, )); @@ -106,7 +107,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { if other_selector != selector { queue_items.extend(algebraic_expression_to_queue_items( other_selector, - T::zero(), + QuadraticSymbolicExpression::zero(), self.latch_row as i32, &witgen, )); diff --git a/executor/src/witgen/jit/compiler.rs b/executor/src/witgen/jit/compiler.rs index 081cb7118..67725f738 100644 --- a/executor/src/witgen/jit/compiler.rs +++ b/executor/src/witgen/jit/compiler.rs @@ -377,6 +377,7 @@ fn format_effect(effect: &Effect, is_top_level: bo lhs, rhs, expected_equal, + .. }) => format!( "assert!({} {} {});", format_expression(lhs), @@ -534,7 +535,7 @@ fn format_bit_decomposition( } fn format_condition( - Condition { value, condition }: &Condition, + Condition { value, condition }: &Condition>, ) -> String { let value = format!("IntType::from({})", format_expression(value)); let (min, max) = condition.range(); @@ -676,6 +677,7 @@ mod tests { use powdr_ast::analyzed::AlgebraicReference; use powdr_ast::analyzed::FunctionValueDefinition; use powdr_constraint_solver::range_constraint::RangeConstraint; + use powdr_constraint_solver::runtime_constant::RuntimeConstant; use pretty_assertions::assert_eq; use test_log::test; @@ -770,7 +772,7 @@ mod tests { assignment(&x0, number(7) * symbol(&a0)), assignment(&cv1, symbol(&x0)), Effect::MachineCall( - 7.into(), + GoldilocksField::from(7), [false, true].into_iter().collect(), vec![r1.clone(), cv1.clone()], ), diff --git a/executor/src/witgen/jit/debug_formatter.rs b/executor/src/witgen/jit/debug_formatter.rs index eca1c7397..30aa71ab5 100644 --- a/executor/src/witgen/jit/debug_formatter.rs +++ b/executor/src/witgen/jit/debug_formatter.rs @@ -3,7 +3,9 @@ use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicUnaryOperation, PolynomialIdentity, SelectedExpressions, }; -use powdr_constraint_solver::range_constraint::RangeConstraint; +use powdr_constraint_solver::{ + range_constraint::RangeConstraint, runtime_constant::RuntimeConstant, +}; use powdr_number::FieldElement; use crate::witgen::data_structures::identity::{BusSend, Identity}; diff --git a/executor/src/witgen/jit/effect.rs b/executor/src/witgen/jit/effect.rs index 3b27b785b..dfc3257ab 100644 --- a/executor/src/witgen/jit/effect.rs +++ b/executor/src/witgen/jit/effect.rs @@ -8,6 +8,7 @@ use powdr_ast::indent; use powdr_constraint_solver::effect::{ Assertion, BitDecomposition, BitDecompositionComponent, Condition, }; +use powdr_constraint_solver::runtime_constant::RuntimeConstant; use powdr_constraint_solver::symbolic_expression::SymbolicExpression; use powdr_number::FieldElement; @@ -18,24 +19,33 @@ use super::variable::Variable; /// The effect of solving a symbolic equation. #[derive(Clone, PartialEq, Eq)] -pub enum Effect { +pub enum Effect +where + SymbolicExpression: RuntimeConstant, +{ /// Variable can be assigned a value. Assignment(V, SymbolicExpression), /// Perform a bit decomposition of a known value, and assign multiple variables. - BitDecomposition(BitDecomposition), + BitDecomposition(BitDecomposition, V>), /// We learnt a new range constraint on variable. RangeConstraint(V, RangeConstraint), /// A run-time assertion. If this fails, we have conflicting constraints. - Assertion(Assertion), + Assertion(Assertion>), /// A call to a different machine, with bus ID, known inputs and argument variables. MachineCall(T, BitVec, Vec), /// Compute one variable by executing a prover function (given by index) on the value of other variables. ProverFunctionCall(ProverFunctionCall), /// A branch on a variable. - Branch(Condition, Vec>, Vec>), + Branch( + Condition>, + Vec>, + Vec>, + ), } -impl From> for Effect { +impl From> + for Effect +{ fn from(effect: ConstraintSolverEffect) -> Self { match effect { ConstraintSolverEffect::Assignment(v, expr) => Effect::Assignment(v, expr), @@ -87,7 +97,10 @@ impl Effect { } } -impl Effect { +impl Effect +where + SymbolicExpression: RuntimeConstant, +{ /// Returns an iterator over all referenced variables, both read and written. pub fn referenced_variables(&self) -> impl Iterator { let iter: Box> = match self { @@ -143,6 +156,7 @@ pub fn format_code(effects: &[Effect]) -> String { lhs, rhs, expected_equal, + .. }) => { format!( "assert {lhs} {} {rhs};", @@ -206,7 +220,7 @@ pub fn format_code(effects: &[Effect]) -> String { } fn format_condition( - Condition { value, condition }: &Condition, + Condition { value, condition }: &Condition>, ) -> String { let (min, max) = condition.range(); match min.cmp(&max) { diff --git a/executor/src/witgen/jit/identity_queue.rs b/executor/src/witgen/jit/identity_queue.rs index a249a11c3..bd9e8041c 100644 --- a/executor/src/witgen/jit/identity_queue.rs +++ b/executor/src/witgen/jit/identity_queue.rs @@ -12,6 +12,7 @@ use powdr_ast::{ }; use powdr_constraint_solver::{ quadratic_symbolic_expression::QuadraticSymbolicExpression, + runtime_constant::RuntimeConstant, symbolic_expression::SymbolicExpression, variable_update::{UpdateKind, VariableUpdate}, }; diff --git a/executor/src/witgen/jit/interpreter.rs b/executor/src/witgen/jit/interpreter.rs index b874dffe8..cbff3de6a 100644 --- a/executor/src/witgen/jit/interpreter.rs +++ b/executor/src/witgen/jit/interpreter.rs @@ -80,7 +80,7 @@ enum BranchTest { impl BranchTest { fn new( var_mapper: &mut VariableMapper, - Condition { value, condition }: &Condition, + Condition { value, condition }: &Condition>, ) -> Self { let (min, max) = condition.range(); let value = var_mapper.map_expr_to_rpn(value); diff --git a/executor/src/witgen/jit/processor.rs b/executor/src/witgen/jit/processor.rs index bbab56534..1b029363d 100644 --- a/executor/src/witgen/jit/processor.rs +++ b/executor/src/witgen/jit/processor.rs @@ -1,6 +1,7 @@ use std::fmt::{self, Display, Formatter, Write}; use itertools::Itertools; +use num_traits::Zero; use powdr_ast::analyzed::{AlgebraicExpression, PolynomialIdentity}; use powdr_constraint_solver::range_constraint::RangeConstraint; use powdr_constraint_solver::symbolic_expression::SymbolicExpression; @@ -48,7 +49,7 @@ pub struct Processor<'a, T: FieldElement> { pub struct ProcessorResult { /// Generated code. pub code: Vec>, - /// Range constrainst of the variables they were requested on. + /// Range constraints of the variables they were requested on. pub range_constraints: Vec>, } @@ -111,7 +112,7 @@ impl<'a, T: FieldElement> Processor<'a, T> { } Identity::Polynomial(identity) => algebraic_expression_to_queue_items( &identity.expression, - T::zero(), + QuadraticSymbolicExpression::zero(), *row_offset, &witgen, ) diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index e80f83048..72eba2030 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -15,6 +15,7 @@ use powdr_constraint_solver::{ Error, ProcessResult, QuadraticSymbolicExpression, RangeConstraintProvider, }, range_constraint::RangeConstraint, + runtime_constant::RuntimeConstant, symbolic_expression::SymbolicExpression, }; use powdr_number::FieldElement; @@ -75,7 +76,7 @@ pub struct BranchResult<'a, T: FieldElement, FixedEval> { /// The code common to both branches. pub common_code: Vec>, /// The condition of the branch. - pub condition: Condition, + pub condition: Condition>, /// The two branches. pub branches: [WitgenInference<'a, T, FixedEval>; 2], } @@ -232,7 +233,8 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F /// Set a variable to a fixed value. pub fn set_variable(&mut self, variable: Variable, value: T) -> Result, Error> { self.process_equation( - &(QuadraticSymbolicExpression::from_unknown_variable(variable) - value.into()), + &(QuadraticSymbolicExpression::from_unknown_variable(variable) + - QuadraticSymbolicExpression::from_number(value)), ) } @@ -466,7 +468,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F self, ), Expression::PublicReference(_) | Expression::Challenge(_) => todo!(), - Expression::Number(n) => (*n).into(), + Expression::Number(n) => QuadraticSymbolicExpression::from_number(*n), Expression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { let left = self.evaluate(left, row_offset, require_concretely_known); let right = self.evaluate(right, row_offset, require_concretely_known); diff --git a/expression/src/conversion.rs b/expression/src/conversion.rs deleted file mode 100644 index cbb1fe11a..000000000 --- a/expression/src/conversion.rs +++ /dev/null @@ -1,38 +0,0 @@ -use powdr_number::FieldElement; - -use super::{ - AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, - AlgebraicUnaryOperation, AlgebraicUnaryOperator, -}; - -/// Converts an AlgebraicExpression into a different structure that supports algebraic operations. -/// The `reference_converter` is used to convert the reference that appear in the expression. -pub fn convert( - expr: &AlgebraicExpression, - reference_converter: &mut impl FnMut(&R) -> Target, -) -> Target -where - Target: From - + Clone - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Neg, -{ - match expr { - AlgebraicExpression::Reference(r) => reference_converter(r), - AlgebraicExpression::Number(n) => (*n).into(), - AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { - let left = convert(left, reference_converter); - let right = convert(right, reference_converter); - match op { - AlgebraicBinaryOperator::Add => left + right, - AlgebraicBinaryOperator::Sub => left - right, - AlgebraicBinaryOperator::Mul => left * right, - } - } - AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => match op { - AlgebraicUnaryOperator::Minus => -convert(expr, reference_converter), - }, - } -} diff --git a/expression/src/lib.rs b/expression/src/lib.rs index 9b094243c..736eb2db3 100644 --- a/expression/src/lib.rs +++ b/expression/src/lib.rs @@ -1,9 +1,12 @@ -use std::{iter, ops}; +use std::{ + iter, + ops::{self, Add, Mul, Neg, Sub}, +}; +use powdr_number::ExpressionConvertible; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -pub mod conversion; pub mod display; pub mod visitors; @@ -164,3 +167,33 @@ impl From for AlgebraicExpression { AlgebraicExpression::Number(value) } } + +impl ExpressionConvertible for AlgebraicExpression { + fn to_expression< + E: Add + Sub + Mul + Neg, + >( + &self, + number_converter: &impl Fn(&T) -> E, + var_converter: &impl Fn(&R) -> E, + ) -> E { + match self { + AlgebraicExpression::Reference(r) => var_converter(r), + AlgebraicExpression::Number(n) => number_converter(n), + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { + let left = left.to_expression(number_converter, var_converter); + let right = right.to_expression(number_converter, var_converter); + + match op { + AlgebraicBinaryOperator::Add => left + right, + AlgebraicBinaryOperator::Sub => left - right, + AlgebraicBinaryOperator::Mul => left * right, + } + } + AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => match op { + AlgebraicUnaryOperator::Minus => { + -expr.to_expression(number_converter, var_converter) + } + }, + } + } +} diff --git a/number/src/expression_convertible.rs b/number/src/expression_convertible.rs new file mode 100644 index 000000000..ec7610cf2 --- /dev/null +++ b/number/src/expression_convertible.rs @@ -0,0 +1,49 @@ +use std::ops::{Add, Mul, Neg, Sub}; + +use crate::FieldElement; + +pub trait ExpressionConvertible { + /// Converts `self` into a structure that supports algebraic operations. + /// + /// Fails in case a non-algebraic operation is used. + /// + /// The `try_to_number` function is used to check if some conversions can be simplified. + /// + /// This or `to_expression` must be implemented. + fn try_to_expression< + E: Add + Sub + Mul + Neg, + >( + &self, + number_converter: &impl Fn(&T) -> E, + var_converter: &impl Fn(&V) -> E, + _try_to_number: &impl Fn(&E) -> Option, + ) -> Option { + Some(self.to_expression(number_converter, var_converter)) + } + + /// Converts `self` into a structure that supports algebraic operations. + /// + /// This or `try_to_expression` must be implemented. + fn to_expression< + E: Add + Sub + Mul + Neg, + >( + &self, + number_converter: &impl Fn(&T) -> E, + var_converter: &impl Fn(&V) -> E, + ) -> E { + self.try_to_expression(number_converter, var_converter, &|_| unreachable!()) + .unwrap() + } +} + +impl ExpressionConvertible for T { + fn to_expression< + E: Add + Sub + Mul + Neg, + >( + &self, + number_converter: &impl Fn(&T) -> E, + _var_converter: &impl Fn(&V) -> E, + ) -> E { + number_converter(self) + } +} diff --git a/number/src/lib.rs b/number/src/lib.rs index cfa0c2dd2..513397029 100644 --- a/number/src/lib.rs +++ b/number/src/lib.rs @@ -9,14 +9,17 @@ mod koala_bear; mod mersenne31; #[macro_use] mod plonky3_macros; +mod expression_convertible; mod serialize; mod traits; + pub use serialize::{ buffered_write_file, read_polys_csv_file, write_polys_csv_file, CsvRenderMode, ReadWrite, }; pub use baby_bear::BabyBearField; pub use bn254::Bn254Field; +pub use expression_convertible::ExpressionConvertible; pub use goldilocks::GoldilocksField; pub use koala_bear::KoalaBearField; pub use mersenne31::Mersenne31Field; diff --git a/pilopt/src/qse_opt.rs b/pilopt/src/qse_opt.rs index d316dc903..d68e2c2a2 100644 --- a/pilopt/src/qse_opt.rs +++ b/pilopt/src/qse_opt.rs @@ -9,6 +9,7 @@ use powdr_ast::analyzed::{ }; use powdr_constraint_solver::constraint_system::ConstraintSystem; use powdr_constraint_solver::indexed_constraint_system::apply_substitutions; +use powdr_constraint_solver::runtime_constant::RuntimeConstant; use powdr_constraint_solver::{ quadratic_symbolic_expression::QuadraticSymbolicExpression, solver::{self, SolveResult}, @@ -114,7 +115,7 @@ pub fn algebraic_to_quadratic_symbolic_expression( struct TerminalConverter; - impl algebraic_expression_conversion::TerminalConverter> + impl algebraic_expression_conversion::TerminalConverter> for TerminalConverter { fn convert_reference(&mut self, reference: &AlgebraicReference) -> Qse { @@ -126,6 +127,9 @@ pub fn algebraic_to_quadratic_symbolic_expression( fn convert_challenge(&mut self, challenge: &Challenge) -> Qse { Qse::from_unknown_variable(Variable::Challenge(*challenge)) } + fn convert_number(&mut self, number: &T) -> Qse { + Qse::from_number(*number) + } } algebraic_expression_conversion::convert(expr, &mut TerminalConverter)