diff --git a/src/witness_generator/affine_expression.rs b/src/witness_generator/affine_expression.rs index 8bdb21c8b..1da08441c 100644 --- a/src/witness_generator/affine_expression.rs +++ b/src/witness_generator/affine_expression.rs @@ -1,9 +1,9 @@ -use std::{collections::HashSet, ops::Not}; +use std::ops::Not; // TODO this should probably rather be a finite field element. use crate::number::{format_number, is_zero, AbstractNumberType, GOLDILOCKS_MOD}; -use super::bit_constraints::{BitConstraint, BitConstraintSet}; +use super::bit_constraints::BitConstraintSet; use super::eval_error::EvalError::ConflictingBitConstraints; use super::util::WitnessColumnNamer; use super::Constraint; @@ -186,7 +186,7 @@ impl AffineExpression { parts .reduce(|c1, c2| match (c1, c2) { - (Some(c1), Some(c2)) => c1.try_combine(&c2), + (Some(c1), Some(c2)) => c1.try_combine_sum(&c2), _ => None, }) .flatten() @@ -206,6 +206,7 @@ impl AffineExpression { .map(|(i, coeff)| { ( i, + coeff, known_constraints .bit_constraint(i) .unwrap() @@ -213,25 +214,24 @@ impl AffineExpression { ) }) .collect::>(); - if parts.iter().any(|(_i, con)| con.is_none()) { + if parts.iter().any(|(_i, _coeff, con)| con.is_none()) { return Ok(vec![]); } // Check if they are mutually exclusive and compute assignments. - let mut covered_bits = HashSet::::new(); + let mut covered_bits: AbstractNumberType = 0.into(); let mut assignments = vec![]; let mut offset = clamp(-self.offset.clone()); - for (i, con) in parts { - let con = con.clone().unwrap(); - let BitConstraint { min_bit, max_bit } = con; - for bit in min_bit..=max_bit { - if !covered_bits.insert(bit) { - return Ok(vec![]); - } + for (i, coeff, constraint) in parts { + let constraint = constraint.clone().unwrap(); + let mask = constraint.mask(); + if mask.clone() & covered_bits.clone() != 0.into() { + return Ok(vec![]); + } else { + covered_bits |= mask.clone(); } - let mask: AbstractNumberType = con.mask(); assignments.push(( i, - Constraint::Assignment((offset.clone() & mask.clone()) >> min_bit), + Constraint::Assignment((offset.clone() & mask.clone()) / coeff.clone()), )); offset &= mask.not(); } @@ -429,28 +429,40 @@ mod test { - AffineExpression::from_witness_poly_value(3); let known_constraints = TestBitConstraints( vec![ - (2, BitConstraint::from_max(7)), - (3, BitConstraint::from_max(3)), + (2, BitConstraint::from_max_bit(7)), + (3, BitConstraint::from_max_bit(3)), ] .into_iter() .collect(), ); assert_eq!( expr.solve_with_bit_constraints(&known_constraints).unwrap(), - vec![(1, Constraint::BitConstraint(BitConstraint::from_max(11)))] + vec![( + 1, + Constraint::BitConstraint(BitConstraint::from_max_bit(11)) + )] ); assert_eq!( (-expr) .solve_with_bit_constraints(&known_constraints) .unwrap(), - vec![(1, Constraint::BitConstraint(BitConstraint::from_max(11)))] + vec![( + 1, + Constraint::BitConstraint(BitConstraint::from_max_bit(11)) + )] ); // Replace factor 16 by 32. let expr = AffineExpression::from_witness_poly_value(1) - AffineExpression::from_witness_poly_value(2).mul(32.into()) - AffineExpression::from_witness_poly_value(3); - assert!(expr.solve_with_bit_constraints(&known_constraints).is_err()); + assert_eq!( + expr.solve_with_bit_constraints(&known_constraints).unwrap(), + vec![( + 1, + Constraint::BitConstraint(BitConstraint::from_mask(0x1fef.into())) + )] + ); // Replace factor 16 by 8. let expr = AffineExpression::from_witness_poly_value(1) @@ -467,8 +479,8 @@ mod test { - AffineExpression::from_witness_poly_value(3); let known_constraints = TestBitConstraints( vec![ - (2, BitConstraint::from_max(7)), - (3, BitConstraint::from_max(3)), + (2, BitConstraint::from_max_bit(7)), + (3, BitConstraint::from_max_bit(3)), ] .into_iter() .collect(), @@ -491,8 +503,8 @@ mod test { - AffineExpression::from_witness_poly_value(3); let known_constraints = TestBitConstraints( vec![ - (2, BitConstraint::from_max(7)), - (3, BitConstraint::from_max(3)), + (2, BitConstraint::from_max_bit(7)), + (3, BitConstraint::from_max_bit(3)), ] .into_iter() .collect(), diff --git a/src/witness_generator/bit_constraints.rs b/src/witness_generator/bit_constraints.rs index b4ca7df38..c6e65012c 100644 --- a/src/witness_generator/bit_constraints.rs +++ b/src/witness_generator/bit_constraints.rs @@ -13,49 +13,52 @@ use super::{Constraint, FixedData}; /// All bits smaller than min_bit have to be zero /// and all bits larger than max_bit have to be zero. /// The least significant bit is bit zero. -#[derive(PartialEq, Debug, Clone)] +#[derive(PartialEq, Clone)] pub struct BitConstraint { - pub min_bit: u64, - pub max_bit: u64, + mask: AbstractNumberType, } impl BitConstraint { - pub fn from_max(max_bit: u64) -> Self { + pub fn from_max_bit(max_bit: u64) -> Self { + assert!(max_bit < 1024); BitConstraint { - min_bit: 0, - max_bit, + mask: (AbstractNumberType::from(1) << (max_bit + 1)) - AbstractNumberType::from(1), } } + pub fn from_mask(mask: AbstractNumberType) -> Self { + BitConstraint { mask } + } + /// The bit constraint of the sum of two expressions. - pub fn try_combine(&self, other: &BitConstraint) -> Option { - if self.max_bit + 1 == other.min_bit { + pub fn try_combine_sum(&self, other: &BitConstraint) -> Option { + if self.mask.clone() & other.mask.clone() == 0.into() { Some(BitConstraint { - min_bit: self.min_bit, - max_bit: other.max_bit, - }) - } else if other.max_bit + 1 == self.min_bit { - Some(BitConstraint { - min_bit: other.min_bit, - max_bit: self.max_bit, + mask: self.mask.clone() | other.mask.clone(), }) } else { None } } + /// Returns the conjunction of this constraint and the other. + pub fn conjunction(self, other: &BitConstraint) -> BitConstraint { + BitConstraint { + mask: self.mask & other.mask.clone(), + } + } + /// The bit constraint of an integer multiple of an expression. /// TODO this assumes goldilocks pub fn multiple(&self, factor: AbstractNumberType) -> Option { - if factor.clone() * (1u64 << self.max_bit) >= GOLDILOCKS_MOD.into() { + if factor.clone() * self.mask.clone() >= GOLDILOCKS_MOD.into() { None } else { // TODO use binary logarithm (0..64).find_map(|i| { if factor.clone() == (1u64 << i).into() { Some(BitConstraint { - min_bit: self.min_bit + i, - max_bit: self.max_bit + i, + mask: self.mask.clone() << i, }) } else { None @@ -65,7 +68,7 @@ impl BitConstraint { } pub fn mask(&self) -> AbstractNumberType { - ((AbstractNumberType::from(1) << (1 + self.max_bit - self.min_bit)) - 1) << self.min_bit + self.mask.clone() } } @@ -75,6 +78,14 @@ impl Display for BitConstraint { } } +impl core::fmt::Debug for BitConstraint { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BitConstraint") + .field("mask", &format!("0x{:x}", &self.mask)) + .finish() + } +} + /// Trait that provides a bit constraint on a symbolic variable if given by ID. pub trait BitConstraintSet { fn bit_constraint(&self, id: usize) -> Option; @@ -102,21 +113,24 @@ pub fn determine_global_constraints<'a>( identities: Vec<&'a Identity>, ) -> (BTreeMap<&'a str, BitConstraint>, Vec<&'a Identity>) { let mut known_constraints = BTreeMap::new(); + // For these columns, we know that they are not only constrained to those bits + // but also have one row for each possible value. + // It allows us to completely remove some lookups. + let mut full_span = BTreeSet::new(); for (&name, &values) in &fixed_data.fixed_cols { - if let Some(cons) = process_fixed_column(values) { - known_constraints.insert(name, cons); + println!("constr for {name}"); + if let Some((cons, full)) = process_fixed_column(values) { + assert!(known_constraints.insert(name, cons).is_none()); + if full { + full_span.insert(name); + } } } - // For these columns, we know that they are not only constrained to those bits - // but also have one row for each possible value. - let full_span = known_constraints.keys().copied().collect::>(); - - if fixed_data.verbose { - println!("Determined the following bit constraints on fixed columns:"); - for (name, con) in &known_constraints { - println!(" {name}: {con}"); - } + //if fixed_data.verbose { + println!("Determined the following bit constraints on fixed columns:"); + for (name, con) in &known_constraints { + println!(" {name}: {con}"); } let mut retained_identities = vec![]; @@ -146,21 +160,27 @@ pub fn determine_global_constraints<'a>( /// Analyzes a fixed column and checks if its values correspond exactly /// to a certain bit pattern. /// TODO do this on the symbolic definition instead of the values. -fn process_fixed_column(fixed: &[AbstractNumberType]) -> Option { +fn process_fixed_column(fixed: &[AbstractNumberType]) -> Option<(BitConstraint, bool)> { if let Some(bit) = smallest_period_candidate(fixed) { - let mask: u64 = (1u64 << bit) - 1; - for (i, v) in fixed.iter().enumerate() { - if *v != (i as u64 & mask).into() { - return None; - } + let mask: AbstractNumberType = + (AbstractNumberType::from(1) << bit) - AbstractNumberType::from(1); + if fixed + .iter() + .enumerate() + .all(|(i, v)| *v == AbstractNumberType::from(i) & mask.clone()) + { + return Some((BitConstraint::from_mask(mask), true)); } - Some(BitConstraint { - min_bit: 0, - max_bit: bit - 1, - }) - } else { - None } + let mut mask = 0.into(); + for v in fixed.iter() { + if *v < 0.into() { + return None; + } + mask |= v.clone(); + } + + Some((BitConstraint::from_mask(mask), false)) } /// Deduces new bit constraints on witness columns from constraints on fixed columns @@ -180,7 +200,7 @@ fn propagate_constraints<'a>( is_binary_constraint(fixed_data, identity.left.selector.as_ref().unwrap()) { assert!(known_constraints - .insert(p, BitConstraint::from_max(0)) + .insert(p, BitConstraint::from_max_bit(0)) .is_none()); remove = true; } else if let Some((p, c)) = try_transfer_constraints( @@ -188,7 +208,10 @@ fn propagate_constraints<'a>( identity.left.selector.as_ref().unwrap(), &known_constraints, ) { - assert!(known_constraints.insert(p, c).is_none()); + known_constraints + .entry(p) + .and_modify(|existing| *existing = existing.clone().conjunction(&c)) + .or_insert(c); } } IdentityKind::Plookup | IdentityKind::Permutation | IdentityKind::Connect => { @@ -203,7 +226,12 @@ fn propagate_constraints<'a>( { if let (Some(left), Some(right)) = (is_simple_poly(left), is_simple_poly(right)) { if let Some(constraint) = known_constraints.get(right).cloned() { - assert!(known_constraints.insert(left, constraint).is_none()); + known_constraints + .entry(left) + .and_modify(|existing| { + *existing = existing.clone().conjunction(&constraint) + }) + .or_insert(constraint); } } } @@ -324,7 +352,10 @@ mod test { #[test] fn all_zeros() { let fixed = [0, 0, 0, 0].iter().map(|v| (*v).into()).collect::>(); - assert_eq!(process_fixed_column(&fixed), None); + assert_eq!( + process_fixed_column(&fixed), + Some((BitConstraint::from_mask(0.into()), false)) + ); } #[test] @@ -335,10 +366,7 @@ mod test { .collect::>(); assert_eq!( process_fixed_column(&fixed), - Some(BitConstraint { - min_bit: 0, - max_bit: 0 - }) + Some((BitConstraint::from_mask(1.into()), true)) ); } @@ -350,10 +378,19 @@ mod test { .collect::>(); assert_eq!( process_fixed_column(&fixed), - Some(BitConstraint { - min_bit: 0, - max_bit: 1 - }) + Some((BitConstraint::from_mask(3.into()), true)) + ); + } + + #[test] + fn various_with_bit_mask() { + let fixed = [0, 6, 0x0100, 0x1100, 2] + .iter() + .map(|v| (*v).into()) + .collect::>(); + assert_eq!( + process_fixed_column(&fixed), + Some((BitConstraint::from_mask(0x1106.into()), false)) ); } @@ -363,27 +400,32 @@ mod test { namespace Global(2**20); col fixed BYTE(i) { i & 0xff }; col fixed BYTE2(i) { i & 0xffff }; + col fixed SHIFTED(i) { i & 0xff0 }; col witness A; // A bit more complicated to see that the 'pattern matcher' works properly. (1 - A + 0) * (A + 1 - 1) = 0; col witness B; { B } in { BYTE }; col witness C; - C = A * 2**8 + B; + C = A * 512 + B; + col witness D; + { D } in { BYTE }; + { D } in { SHIFTED }; "; let analyzed = crate::analyzer::analyze_string(pil_source); let (constants, degree) = crate::constant_evaluator::generate(&analyzed); let mut known_constraints = constants .iter() .filter_map(|(name, values)| { - process_fixed_column(values).map(|constraint| (*name, constraint)) + process_fixed_column(values).map(|(constraint, _full)| (*name, constraint)) }) .collect::>(); assert_eq!( known_constraints, vec![ - ("Global.BYTE", BitConstraint::from_max(7)), - ("Global.BYTE2", BitConstraint::from_max(15)), + ("Global.BYTE", BitConstraint::from_max_bit(7)), + ("Global.BYTE2", BitConstraint::from_max_bit(15)), + ("Global.SHIFTED", BitConstraint::from_mask(0xff0.into())), ] .into_iter() .collect() @@ -419,11 +461,13 @@ namespace Global(2**20); assert_eq!( known_constraints, vec![ - ("Global.A", BitConstraint::from_max(0)), - ("Global.B", BitConstraint::from_max(7)), - ("Global.C", BitConstraint::from_max(8)), - ("Global.BYTE", BitConstraint::from_max(7)), - ("Global.BYTE2", BitConstraint::from_max(15)), + ("Global.A", BitConstraint::from_max_bit(0)), + ("Global.B", BitConstraint::from_max_bit(7)), + ("Global.C", BitConstraint::from_mask(0x2ff.into())), + ("Global.D", BitConstraint::from_mask(0xf0.into())), + ("Global.BYTE", BitConstraint::from_max_bit(7)), + ("Global.BYTE2", BitConstraint::from_max_bit(15)), + ("Global.SHIFTED", BitConstraint::from_mask(0xff0.into())), ] .into_iter() .collect() @@ -432,21 +476,24 @@ namespace Global(2**20); #[test] fn combinations() { - let a = BitConstraint::from_max(7); + let a = BitConstraint::from_max_bit(7); + assert_eq!(a, BitConstraint::from_mask(0xff.into())); let b = a.multiple(256.into()).unwrap(); + assert_eq!(b, BitConstraint::from_mask(0xff00.into())); assert_eq!( - b, - BitConstraint { - min_bit: 8, - max_bit: 15 - } + b.try_combine_sum(&a).unwrap(), + BitConstraint::from_mask(0xffff.into()) ); + } + + #[test] + fn weird_combinations() { + let a = BitConstraint::from_mask(0xf00f.into()); + let b = a.multiple(256.into()).unwrap(); + assert_eq!(b, BitConstraint::from_mask(0xf00f00.into())); assert_eq!( - b.try_combine(&a).unwrap(), - BitConstraint { - min_bit: 0, - max_bit: 15 - } + b.try_combine_sum(&a).unwrap(), + BitConstraint::from_mask(0xf0ff0f.into()) ); } }