Merge pull request #128 from chriseth/improved_bit_constraints

Improved bit constraints.
This commit is contained in:
chriseth
2023-03-29 09:02:57 +02:00
committed by GitHub
2 changed files with 156 additions and 97 deletions

View File

@@ -1,9 +1,9 @@
use std::{collections::HashSet, ops::Not};
use std::ops::Not;
// TODO this should probably rather be a finite field element.
use crate::number::{format_number, is_zero, AbstractNumberType, GOLDILOCKS_MOD};
use super::bit_constraints::{BitConstraint, BitConstraintSet};
use super::bit_constraints::BitConstraintSet;
use super::eval_error::EvalError::ConflictingBitConstraints;
use super::util::WitnessColumnNamer;
use super::Constraint;
@@ -186,7 +186,7 @@ impl AffineExpression {
parts
.reduce(|c1, c2| match (c1, c2) {
(Some(c1), Some(c2)) => c1.try_combine(&c2),
(Some(c1), Some(c2)) => c1.try_combine_sum(&c2),
_ => None,
})
.flatten()
@@ -206,6 +206,7 @@ impl AffineExpression {
.map(|(i, coeff)| {
(
i,
coeff,
known_constraints
.bit_constraint(i)
.unwrap()
@@ -213,25 +214,24 @@ impl AffineExpression {
)
})
.collect::<Vec<_>>();
if parts.iter().any(|(_i, con)| con.is_none()) {
if parts.iter().any(|(_i, _coeff, con)| con.is_none()) {
return Ok(vec![]);
}
// Check if they are mutually exclusive and compute assignments.
let mut covered_bits = HashSet::<u64>::new();
let mut covered_bits: AbstractNumberType = 0.into();
let mut assignments = vec![];
let mut offset = clamp(-self.offset.clone());
for (i, con) in parts {
let con = con.clone().unwrap();
let BitConstraint { min_bit, max_bit } = con;
for bit in min_bit..=max_bit {
if !covered_bits.insert(bit) {
return Ok(vec![]);
}
for (i, coeff, constraint) in parts {
let constraint = constraint.clone().unwrap();
let mask = constraint.mask();
if mask.clone() & covered_bits.clone() != 0.into() {
return Ok(vec![]);
} else {
covered_bits |= mask.clone();
}
let mask: AbstractNumberType = con.mask();
assignments.push((
i,
Constraint::Assignment((offset.clone() & mask.clone()) >> min_bit),
Constraint::Assignment((offset.clone() & mask.clone()) / coeff.clone()),
));
offset &= mask.not();
}
@@ -429,28 +429,40 @@ mod test {
- AffineExpression::from_witness_poly_value(3);
let known_constraints = TestBitConstraints(
vec![
(2, BitConstraint::from_max(7)),
(3, BitConstraint::from_max(3)),
(2, BitConstraint::from_max_bit(7)),
(3, BitConstraint::from_max_bit(3)),
]
.into_iter()
.collect(),
);
assert_eq!(
expr.solve_with_bit_constraints(&known_constraints).unwrap(),
vec![(1, Constraint::BitConstraint(BitConstraint::from_max(11)))]
vec![(
1,
Constraint::BitConstraint(BitConstraint::from_max_bit(11))
)]
);
assert_eq!(
(-expr)
.solve_with_bit_constraints(&known_constraints)
.unwrap(),
vec![(1, Constraint::BitConstraint(BitConstraint::from_max(11)))]
vec![(
1,
Constraint::BitConstraint(BitConstraint::from_max_bit(11))
)]
);
// Replace factor 16 by 32.
let expr = AffineExpression::from_witness_poly_value(1)
- AffineExpression::from_witness_poly_value(2).mul(32.into())
- AffineExpression::from_witness_poly_value(3);
assert!(expr.solve_with_bit_constraints(&known_constraints).is_err());
assert_eq!(
expr.solve_with_bit_constraints(&known_constraints).unwrap(),
vec![(
1,
Constraint::BitConstraint(BitConstraint::from_mask(0x1fef.into()))
)]
);
// Replace factor 16 by 8.
let expr = AffineExpression::from_witness_poly_value(1)
@@ -467,8 +479,8 @@ mod test {
- AffineExpression::from_witness_poly_value(3);
let known_constraints = TestBitConstraints(
vec![
(2, BitConstraint::from_max(7)),
(3, BitConstraint::from_max(3)),
(2, BitConstraint::from_max_bit(7)),
(3, BitConstraint::from_max_bit(3)),
]
.into_iter()
.collect(),
@@ -491,8 +503,8 @@ mod test {
- AffineExpression::from_witness_poly_value(3);
let known_constraints = TestBitConstraints(
vec![
(2, BitConstraint::from_max(7)),
(3, BitConstraint::from_max(3)),
(2, BitConstraint::from_max_bit(7)),
(3, BitConstraint::from_max_bit(3)),
]
.into_iter()
.collect(),

View File

@@ -13,49 +13,52 @@ use super::{Constraint, FixedData};
/// All bits smaller than min_bit have to be zero
/// and all bits larger than max_bit have to be zero.
/// The least significant bit is bit zero.
#[derive(PartialEq, Debug, Clone)]
#[derive(PartialEq, Clone)]
pub struct BitConstraint {
pub min_bit: u64,
pub max_bit: u64,
mask: AbstractNumberType,
}
impl BitConstraint {
pub fn from_max(max_bit: u64) -> Self {
pub fn from_max_bit(max_bit: u64) -> Self {
assert!(max_bit < 1024);
BitConstraint {
min_bit: 0,
max_bit,
mask: (AbstractNumberType::from(1) << (max_bit + 1)) - AbstractNumberType::from(1),
}
}
pub fn from_mask(mask: AbstractNumberType) -> Self {
BitConstraint { mask }
}
/// The bit constraint of the sum of two expressions.
pub fn try_combine(&self, other: &BitConstraint) -> Option<BitConstraint> {
if self.max_bit + 1 == other.min_bit {
pub fn try_combine_sum(&self, other: &BitConstraint) -> Option<BitConstraint> {
if self.mask.clone() & other.mask.clone() == 0.into() {
Some(BitConstraint {
min_bit: self.min_bit,
max_bit: other.max_bit,
})
} else if other.max_bit + 1 == self.min_bit {
Some(BitConstraint {
min_bit: other.min_bit,
max_bit: self.max_bit,
mask: self.mask.clone() | other.mask.clone(),
})
} else {
None
}
}
/// Returns the conjunction of this constraint and the other.
pub fn conjunction(self, other: &BitConstraint) -> BitConstraint {
BitConstraint {
mask: self.mask & other.mask.clone(),
}
}
/// The bit constraint of an integer multiple of an expression.
/// TODO this assumes goldilocks
pub fn multiple(&self, factor: AbstractNumberType) -> Option<BitConstraint> {
if factor.clone() * (1u64 << self.max_bit) >= GOLDILOCKS_MOD.into() {
if factor.clone() * self.mask.clone() >= GOLDILOCKS_MOD.into() {
None
} else {
// TODO use binary logarithm
(0..64).find_map(|i| {
if factor.clone() == (1u64 << i).into() {
Some(BitConstraint {
min_bit: self.min_bit + i,
max_bit: self.max_bit + i,
mask: self.mask.clone() << i,
})
} else {
None
@@ -65,7 +68,7 @@ impl BitConstraint {
}
pub fn mask(&self) -> AbstractNumberType {
((AbstractNumberType::from(1) << (1 + self.max_bit - self.min_bit)) - 1) << self.min_bit
self.mask.clone()
}
}
@@ -75,6 +78,14 @@ impl Display for BitConstraint {
}
}
impl core::fmt::Debug for BitConstraint {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BitConstraint")
.field("mask", &format!("0x{:x}", &self.mask))
.finish()
}
}
/// Trait that provides a bit constraint on a symbolic variable if given by ID.
pub trait BitConstraintSet {
fn bit_constraint(&self, id: usize) -> Option<BitConstraint>;
@@ -102,21 +113,24 @@ pub fn determine_global_constraints<'a>(
identities: Vec<&'a Identity>,
) -> (BTreeMap<&'a str, BitConstraint>, Vec<&'a Identity>) {
let mut known_constraints = BTreeMap::new();
// For these columns, we know that they are not only constrained to those bits
// but also have one row for each possible value.
// It allows us to completely remove some lookups.
let mut full_span = BTreeSet::new();
for (&name, &values) in &fixed_data.fixed_cols {
if let Some(cons) = process_fixed_column(values) {
known_constraints.insert(name, cons);
println!("constr for {name}");
if let Some((cons, full)) = process_fixed_column(values) {
assert!(known_constraints.insert(name, cons).is_none());
if full {
full_span.insert(name);
}
}
}
// For these columns, we know that they are not only constrained to those bits
// but also have one row for each possible value.
let full_span = known_constraints.keys().copied().collect::<BTreeSet<_>>();
if fixed_data.verbose {
println!("Determined the following bit constraints on fixed columns:");
for (name, con) in &known_constraints {
println!(" {name}: {con}");
}
//if fixed_data.verbose {
println!("Determined the following bit constraints on fixed columns:");
for (name, con) in &known_constraints {
println!(" {name}: {con}");
}
let mut retained_identities = vec![];
@@ -146,21 +160,27 @@ pub fn determine_global_constraints<'a>(
/// Analyzes a fixed column and checks if its values correspond exactly
/// to a certain bit pattern.
/// TODO do this on the symbolic definition instead of the values.
fn process_fixed_column(fixed: &[AbstractNumberType]) -> Option<BitConstraint> {
fn process_fixed_column(fixed: &[AbstractNumberType]) -> Option<(BitConstraint, bool)> {
if let Some(bit) = smallest_period_candidate(fixed) {
let mask: u64 = (1u64 << bit) - 1;
for (i, v) in fixed.iter().enumerate() {
if *v != (i as u64 & mask).into() {
return None;
}
let mask: AbstractNumberType =
(AbstractNumberType::from(1) << bit) - AbstractNumberType::from(1);
if fixed
.iter()
.enumerate()
.all(|(i, v)| *v == AbstractNumberType::from(i) & mask.clone())
{
return Some((BitConstraint::from_mask(mask), true));
}
Some(BitConstraint {
min_bit: 0,
max_bit: bit - 1,
})
} else {
None
}
let mut mask = 0.into();
for v in fixed.iter() {
if *v < 0.into() {
return None;
}
mask |= v.clone();
}
Some((BitConstraint::from_mask(mask), false))
}
/// Deduces new bit constraints on witness columns from constraints on fixed columns
@@ -180,7 +200,7 @@ fn propagate_constraints<'a>(
is_binary_constraint(fixed_data, identity.left.selector.as_ref().unwrap())
{
assert!(known_constraints
.insert(p, BitConstraint::from_max(0))
.insert(p, BitConstraint::from_max_bit(0))
.is_none());
remove = true;
} else if let Some((p, c)) = try_transfer_constraints(
@@ -188,7 +208,10 @@ fn propagate_constraints<'a>(
identity.left.selector.as_ref().unwrap(),
&known_constraints,
) {
assert!(known_constraints.insert(p, c).is_none());
known_constraints
.entry(p)
.and_modify(|existing| *existing = existing.clone().conjunction(&c))
.or_insert(c);
}
}
IdentityKind::Plookup | IdentityKind::Permutation | IdentityKind::Connect => {
@@ -203,7 +226,12 @@ fn propagate_constraints<'a>(
{
if let (Some(left), Some(right)) = (is_simple_poly(left), is_simple_poly(right)) {
if let Some(constraint) = known_constraints.get(right).cloned() {
assert!(known_constraints.insert(left, constraint).is_none());
known_constraints
.entry(left)
.and_modify(|existing| {
*existing = existing.clone().conjunction(&constraint)
})
.or_insert(constraint);
}
}
}
@@ -324,7 +352,10 @@ mod test {
#[test]
fn all_zeros() {
let fixed = [0, 0, 0, 0].iter().map(|v| (*v).into()).collect::<Vec<_>>();
assert_eq!(process_fixed_column(&fixed), None);
assert_eq!(
process_fixed_column(&fixed),
Some((BitConstraint::from_mask(0.into()), false))
);
}
#[test]
@@ -335,10 +366,7 @@ mod test {
.collect::<Vec<_>>();
assert_eq!(
process_fixed_column(&fixed),
Some(BitConstraint {
min_bit: 0,
max_bit: 0
})
Some((BitConstraint::from_mask(1.into()), true))
);
}
@@ -350,10 +378,19 @@ mod test {
.collect::<Vec<_>>();
assert_eq!(
process_fixed_column(&fixed),
Some(BitConstraint {
min_bit: 0,
max_bit: 1
})
Some((BitConstraint::from_mask(3.into()), true))
);
}
#[test]
fn various_with_bit_mask() {
let fixed = [0, 6, 0x0100, 0x1100, 2]
.iter()
.map(|v| (*v).into())
.collect::<Vec<_>>();
assert_eq!(
process_fixed_column(&fixed),
Some((BitConstraint::from_mask(0x1106.into()), false))
);
}
@@ -363,27 +400,32 @@ mod test {
namespace Global(2**20);
col fixed BYTE(i) { i & 0xff };
col fixed BYTE2(i) { i & 0xffff };
col fixed SHIFTED(i) { i & 0xff0 };
col witness A;
// A bit more complicated to see that the 'pattern matcher' works properly.
(1 - A + 0) * (A + 1 - 1) = 0;
col witness B;
{ B } in { BYTE };
col witness C;
C = A * 2**8 + B;
C = A * 512 + B;
col witness D;
{ D } in { BYTE };
{ D } in { SHIFTED };
";
let analyzed = crate::analyzer::analyze_string(pil_source);
let (constants, degree) = crate::constant_evaluator::generate(&analyzed);
let mut known_constraints = constants
.iter()
.filter_map(|(name, values)| {
process_fixed_column(values).map(|constraint| (*name, constraint))
process_fixed_column(values).map(|(constraint, _full)| (*name, constraint))
})
.collect::<BTreeMap<_, _>>();
assert_eq!(
known_constraints,
vec![
("Global.BYTE", BitConstraint::from_max(7)),
("Global.BYTE2", BitConstraint::from_max(15)),
("Global.BYTE", BitConstraint::from_max_bit(7)),
("Global.BYTE2", BitConstraint::from_max_bit(15)),
("Global.SHIFTED", BitConstraint::from_mask(0xff0.into())),
]
.into_iter()
.collect()
@@ -419,11 +461,13 @@ namespace Global(2**20);
assert_eq!(
known_constraints,
vec![
("Global.A", BitConstraint::from_max(0)),
("Global.B", BitConstraint::from_max(7)),
("Global.C", BitConstraint::from_max(8)),
("Global.BYTE", BitConstraint::from_max(7)),
("Global.BYTE2", BitConstraint::from_max(15)),
("Global.A", BitConstraint::from_max_bit(0)),
("Global.B", BitConstraint::from_max_bit(7)),
("Global.C", BitConstraint::from_mask(0x2ff.into())),
("Global.D", BitConstraint::from_mask(0xf0.into())),
("Global.BYTE", BitConstraint::from_max_bit(7)),
("Global.BYTE2", BitConstraint::from_max_bit(15)),
("Global.SHIFTED", BitConstraint::from_mask(0xff0.into())),
]
.into_iter()
.collect()
@@ -432,21 +476,24 @@ namespace Global(2**20);
#[test]
fn combinations() {
let a = BitConstraint::from_max(7);
let a = BitConstraint::from_max_bit(7);
assert_eq!(a, BitConstraint::from_mask(0xff.into()));
let b = a.multiple(256.into()).unwrap();
assert_eq!(b, BitConstraint::from_mask(0xff00.into()));
assert_eq!(
b,
BitConstraint {
min_bit: 8,
max_bit: 15
}
b.try_combine_sum(&a).unwrap(),
BitConstraint::from_mask(0xffff.into())
);
}
#[test]
fn weird_combinations() {
let a = BitConstraint::from_mask(0xf00f.into());
let b = a.multiple(256.into()).unwrap();
assert_eq!(b, BitConstraint::from_mask(0xf00f00.into()));
assert_eq!(
b.try_combine(&a).unwrap(),
BitConstraint {
min_bit: 0,
max_bit: 15
}
b.try_combine_sum(&a).unwrap(),
BitConstraint::from_mask(0xf0ff0f.into())
);
}
}