From ea1adcbd02372e37f6afdb32b750fd4ab89ea5db Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 23 Mar 2023 11:01:46 +0100 Subject: [PATCH] Transfer global bit constraints. --- src/commit_evaluator/affine_expression.rs | 8 +- src/commit_evaluator/bit_constraints.rs | 86 ++++++++++++------- src/commit_evaluator/evaluator.rs | 4 +- .../sorted_witness_machine.rs | 17 ++-- src/commit_evaluator/symbolic_evaluator.rs | 68 ++++++++++----- src/commit_evaluator/util.rs | 11 ++- src/number.rs | 2 +- 7 files changed, 129 insertions(+), 67 deletions(-) diff --git a/src/commit_evaluator/affine_expression.rs b/src/commit_evaluator/affine_expression.rs index 0b365e86b..5dcef410d 100644 --- a/src/commit_evaluator/affine_expression.rs +++ b/src/commit_evaluator/affine_expression.rs @@ -256,7 +256,7 @@ impl AffineExpression { let name = namer.name(i); if *c == 1.into() { name - } else if *c == (-1).into() { + } else if *c == clamp((-1).into()) { format!("-{name}") } else { format!("{} * {name}", format_number(c)) @@ -327,9 +327,9 @@ impl std::ops::Neg for AffineExpression { type Output = AffineExpression; fn neg(mut self) -> Self::Output { - self.coefficients - .iter_mut() - .for_each(|v| *v = clamp(-v.clone())); + self.coefficients.iter_mut().for_each(|v| { + *v = clamp(-v.clone()); + }); self.offset = clamp(-self.offset); self } diff --git a/src/commit_evaluator/bit_constraints.rs b/src/commit_evaluator/bit_constraints.rs index 24548a8de..87d65e208 100644 --- a/src/commit_evaluator/bit_constraints.rs +++ b/src/commit_evaluator/bit_constraints.rs @@ -2,6 +2,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Display, Formatter}; use crate::analyzer::{BinaryOperator, Expression, Identity, IdentityKind, PolynomialReference}; +use crate::commit_evaluator::util::{contains_next_ref, WitnessColumnNamer}; use crate::number::{AbstractNumberType, GOLDILOCKS_MOD}; use super::expression_evaluator::ExpressionEvaluator; @@ -79,6 +80,19 @@ pub trait BitConstraintSet { fn bit_constraint(&self, id: usize) -> Option; } +pub struct SimpleBitConstraintSet<'a, Namer: WitnessColumnNamer> { + bit_constraints: &'a BTreeMap<&'a str, BitConstraint>, + names: &'a Namer, +} + +impl<'a, Namer: WitnessColumnNamer> BitConstraintSet for SimpleBitConstraintSet<'a, Namer> { + fn bit_constraint(&self, id: usize) -> Option { + self.bit_constraints + .get(self.names.name(id).as_str()) + .cloned() + } +} + /// Determines global constraints on witness and fixed columns. /// Removes identities that only serve to create bit constraints from /// the identities vector. @@ -108,19 +122,6 @@ pub fn determine_global_constraints<'a>( } } - // TODO this does not yet transfer constraitns of the form - // col fixed BYTE(i) { i & 0xff }; - // col witness A; - // A * (A - 1) = 0; - // col witness B; - // { B } in { BYTE }; - // col witness C; - // C = A * 2**8 + B; - // -> C is constrained. - // - // We can implement that in try_transfer_constraints but for that, we need a symbolic - // evaluator that handles both fixed and witness columns - (known_constraints, reduced_identities) } @@ -165,6 +166,7 @@ fn propagate_constraints<'a>( .is_none()); remove = true; } else if let Some((p, c)) = try_transfer_constraints( + fixed_data, identity.left.selector.as_ref().unwrap(), &known_constraints, ) { @@ -212,27 +214,27 @@ fn is_binary_constraint<'a>(fixed_data: &'a FixedData, expr: &Expression) -> Opt } } } else if let Expression::BinaryOperation(left, BinaryOperator::Mul, right) = expr { - let symbolic_ev = ExpressionEvaluator::new(SymbolicEvaluator::new(fixed_data)); - let left_root = symbolic_ev + let symbolic_ev = SymbolicEvaluator::new(fixed_data); + let left_root = ExpressionEvaluator::new(symbolic_ev.clone()) .evaluate(left) .ok() .and_then(|l| l.solve().ok())?; - let right_root = symbolic_ev + let right_root = ExpressionEvaluator::new(symbolic_ev.clone()) .evaluate(right) .ok() .and_then(|r| r.solve().ok())?; - if let ( - [(var1, Constraint::Assignment(value1))], - [(var2, Constraint::Assignment(value2))], - ) = (&left_root[..], &right_root[..]) + if let ([(id1, Constraint::Assignment(value1))], [(id2, Constraint::Assignment(value2))]) = + (&left_root[..], &right_root[..]) { - if var1 != var2 || *var1 >= fixed_data.witness_cols.len() { + let poly1 = symbolic_ev.poly_from_id(*id1); + let poly2 = symbolic_ev.poly_from_id(*id2); + if poly1 != poly2 || !fixed_data.witness_ids.contains_key(poly1.0) { return None; } if (*value1 == 0.into() && *value2 == 1.into()) || (*value1 == 1.into() && *value2 == 0.into()) { - return Some(fixed_data.witness_cols[*var1].name); + return Some(poly1.0); } } } @@ -241,12 +243,35 @@ fn is_binary_constraint<'a>(fixed_data: &'a FixedData, expr: &Expression) -> Opt /// Tries to transfer constraints in a linear expression. fn try_transfer_constraints<'a>( - _expr: &'a Expression, - _known_constraints: &BTreeMap<&str, BitConstraint>, + fixed_data: &'a FixedData, + expr: &'a Expression, + known_constraints: &BTreeMap<&str, BitConstraint>, ) -> Option<(&'a str, BitConstraint)> { - None - // TODO we do some of this for each row, but we could also do it globally here. - //todo!(); + if contains_next_ref(expr) { + return None; + } + + let symbolic_ev = SymbolicEvaluator::new(fixed_data); + let aff_expr = ExpressionEvaluator::new(symbolic_ev.clone()) + .evaluate(expr) + .ok()?; + + let result = aff_expr + .solve_with_bit_constraints(&SimpleBitConstraintSet { + bit_constraints: known_constraints, + names: &symbolic_ev, + }) + .ok()?; + assert!(result.len() <= 1); + result.get(0).map(|(id, cons)| { + if let Constraint::BitConstraint(cons) = cons { + let (poly, next) = symbolic_ev.poly_from_id(*id); + assert!(!next); + (poly, cons.clone()) + } else { + panic!(); + } + }) } fn is_simple_poly(expr: &Expression) -> Option<&str> { @@ -325,9 +350,8 @@ namespace Global(2**20); (1 - A + 0) * (A + 1 - 1) = 0; col witness B; { B } in { BYTE }; - // TODO we could infer constraints here in the future. - //col witness C; - //C = A * 2**8 + B; + col witness C; + C = A * 2**8 + B; "; let analyzed = crate::analyzer::analyze_string(pil_source); let (constants, degree) = crate::constant_evaluator::generate(&analyzed); @@ -379,7 +403,7 @@ namespace Global(2**20); vec![ ("Global.A", BitConstraint::from_max(0)), ("Global.B", BitConstraint::from_max(7)), - //("Global.C", BitConstraint::from_max(8)), + ("Global.C", BitConstraint::from_max(8)), ("Global.BYTE", BitConstraint::from_max(7)), ("Global.BYTE2", BitConstraint::from_max(15)), ] diff --git a/src/commit_evaluator/evaluator.rs b/src/commit_evaluator/evaluator.rs index 38597ca9d..38c0f8955 100644 --- a/src/commit_evaluator/evaluator.rs +++ b/src/commit_evaluator/evaluator.rs @@ -10,7 +10,7 @@ use super::bit_constraints::{BitConstraint, BitConstraintSet}; use super::eval_error::EvalError; use super::expression_evaluator::{ExpressionEvaluator, SymbolicVariables}; use super::machine::Machine; -use super::util::contains_next_ref; +use super::util::contains_next_witness_ref; use super::{Constraint, EvalResult, FixedData, WitnessColumn}; pub struct Evaluator<'a, QueryCallback> @@ -249,7 +249,7 @@ where fn process_polynomial_identity(&self, identity: &Expression) -> EvalResult { // If there is no "next" reference in the expression, // we just evaluate it directly on the "next" row. - let row = if contains_next_ref(identity, self.fixed_data) { + let row = if contains_next_witness_ref(identity, self.fixed_data) { EvaluationRow::Current } else { EvaluationRow::Next diff --git a/src/commit_evaluator/sorted_witness_machine.rs b/src/commit_evaluator/sorted_witness_machine.rs index 29582db89..9ffe14975 100644 --- a/src/commit_evaluator/sorted_witness_machine.rs +++ b/src/commit_evaluator/sorted_witness_machine.rs @@ -96,8 +96,8 @@ fn check_identity<'a>(fixed_data: &'a FixedData, id: &Identity) -> Option<&'a st /// Checks that the identity has a constraint of the form `a' - a` as the first expression /// on the left hand side and returns the name of the witness column. fn check_constraint<'a>(fixed_data: &'a FixedData, constraint: &Expression) -> Option<&'a str> { - let symbolic_ev = ExpressionEvaluator::new(SymbolicEvaluator::new(fixed_data)); - let sort_constraint = match symbolic_ev.evaluate(constraint) { + let symbolic_ev = SymbolicEvaluator::new(fixed_data); + let sort_constraint = match ExpressionEvaluator::new(symbolic_ev.clone()).evaluate(constraint) { Ok(c) => c, Err(_) => return None, }; @@ -105,18 +105,21 @@ fn check_constraint<'a>(fixed_data: &'a FixedData, constraint: &Expression) -> O [key, _] => *key, _ => return None, }; - let witness_count = fixed_data.witness_cols.len(); - if key_column_id >= witness_count { + let (poly, next) = symbolic_ev.poly_from_id(key_column_id); + if next || fixed_data.witness_ids.get(poly).is_none() { // Either next-witness or fixed column. return None; } - let pattern = AffineExpression::from_witness_poly_value(key_column_id + witness_count) - - AffineExpression::from_witness_poly_value(key_column_id); + let pattern = + AffineExpression::from_witness_poly_value(symbolic_ev.id_for_witness_poly(poly, true)) + - AffineExpression::from_witness_poly_value( + symbolic_ev.id_for_witness_poly(poly, false), + ); if sort_constraint != pattern { return None; } - Some(fixed_data.witness_cols[key_column_id].name) + Some(poly) } impl Machine for SortedWitnesses { diff --git a/src/commit_evaluator/symbolic_evaluator.rs b/src/commit_evaluator/symbolic_evaluator.rs index 7a04a0f45..33de52890 100644 --- a/src/commit_evaluator/symbolic_evaluator.rs +++ b/src/commit_evaluator/symbolic_evaluator.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeSet, HashMap}; +use std::collections::HashMap; use super::affine_expression::AffineExpression; use super::eval_error::EvalError; @@ -14,27 +14,56 @@ use super::FixedData; /// returned by the EvaluationData struct. /// The only IDs are allocated in the following order: /// witness columns, next witness columns, fixed columns, next fixed columns. +#[derive(Clone)] pub struct SymbolicEvaluator<'a> { fixed_data: &'a FixedData<'a>, - fixed_columns: HashMap<&'a str, usize>, + fixed_ids: HashMap<&'a str, usize>, + fixed_names: Vec<&'a str>, } impl<'a> SymbolicEvaluator<'a> { pub fn new(fixed_data: &'a FixedData<'a>) -> Self { - let fixed_columns = fixed_data - .fixed_cols - .keys() - .cloned() - .collect::>() - .into_iter() + let mut fixed_names = fixed_data.fixed_cols.keys().cloned().collect::>(); + fixed_names.sort(); + let fixed_ids = fixed_names + .iter() .enumerate() - .map(|(i, n)| (n, i)) + .map(|(i, n)| (*n, i)) .collect(); SymbolicEvaluator { fixed_data, - fixed_columns, + fixed_ids, + fixed_names, } } + + pub fn poly_from_id(&self, id: usize) -> (&'a str, bool) { + let witness_count = self.fixed_data.witness_ids.len(); + if id < 2 * witness_count { + ( + self.fixed_data.witness_cols[id % witness_count].name, + id >= witness_count, + ) + } else { + let fixed_count = self.fixed_ids.len(); + let fixed_id = id - 2 * witness_count; + ( + self.fixed_names[fixed_id % fixed_count], + fixed_id >= fixed_count, + ) + } + } + + pub fn id_for_fixed_poly(&self, name: &str, next: bool) -> usize { + let witness_count = self.fixed_data.witness_ids.len(); + let fixed_count = self.fixed_ids.len(); + + 2 * witness_count + self.fixed_ids[name] + if next { fixed_count } else { 0 } + } + pub fn id_for_witness_poly(&self, name: &str, next: bool) -> usize { + let witness_count = self.fixed_data.witness_ids.len(); + self.fixed_data.witness_ids[name] + if next { witness_count } else { 0 } + } } impl<'a> SymbolicVariables for SymbolicEvaluator<'a> { @@ -43,17 +72,14 @@ impl<'a> SymbolicVariables for SymbolicEvaluator<'a> { } fn value(&self, name: &str, next: bool) -> Result { - let witness_count = self.fixed_data.witness_ids.len(); // TODO arrays - if let Some(id) = self.fixed_data.witness_ids.get(name) { + if self.fixed_data.witness_ids.get(name).is_some() { Ok(AffineExpression::from_witness_poly_value( - *id + if next { witness_count } else { 0 }, + self.id_for_witness_poly(name, next), )) } else { - let id = self.fixed_columns[name]; - let fixed_count = self.fixed_data.fixed_cols.len(); Ok(AffineExpression::from_witness_poly_value( - id + witness_count + if next { fixed_count } else { 0 }, + self.id_for_fixed_poly(name, next), )) } } @@ -64,12 +90,12 @@ impl<'a> SymbolicVariables for SymbolicEvaluator<'a> { } impl<'a> WitnessColumnNamer for SymbolicEvaluator<'a> { - fn name(&self, i: usize) -> String { - let witness_count = self.fixed_data.witness_ids.len(); - if i < witness_count { - self.fixed_data.name(i) + fn name(&self, id: usize) -> String { + let (name, next) = self.poly_from_id(id); + if next { + format!("{name}'") } else { - format!("{}'", self.fixed_data.name(i - witness_count)) + name.to_string() } } } diff --git a/src/commit_evaluator/util.rs b/src/commit_evaluator/util.rs index f9f080158..4cb24a3c9 100644 --- a/src/commit_evaluator/util.rs +++ b/src/commit_evaluator/util.rs @@ -6,8 +6,17 @@ pub trait WitnessColumnNamer { fn name(&self, i: usize) -> String; } +/// @returns true if the expression contains a reference to a next value of a +/// (witness or fixed) column +pub fn contains_next_ref(expr: &Expression) -> bool { + expr_any(expr, &mut |e| match e { + Expression::PolynomialReference(poly) => poly.next, + _ => false, + }) +} + /// @returns true if the expression contains a reference to a next value of a witness column. -pub fn contains_next_ref(expr: &Expression, fixed_data: &FixedData) -> bool { +pub fn contains_next_witness_ref(expr: &Expression, fixed_data: &FixedData) -> bool { expr_any(expr, &mut |e| match e { Expression::PolynomialReference(poly) => { poly.next && fixed_data.witness_ids.contains_key(poly.name.as_str()) diff --git a/src/number.rs b/src/number.rs index 055fd9c34..757c497b2 100644 --- a/src/number.rs +++ b/src/number.rs @@ -23,7 +23,7 @@ pub const GOLDILOCKS_MOD: u64 = 0xffffffff00000001u64; pub fn format_number(x: &AbstractNumberType) -> String { if *x > (GOLDILOCKS_MOD / 2).into() { - format!("{}", GOLDILOCKS_MOD - x) + format!("-{}", GOLDILOCKS_MOD - x) } else { format!("{x}") }