mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-05-13 03:00:26 -04:00
Merge pull request #128 from chriseth/improved_bit_constraints
Improved bit constraints.
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
use std::{collections::HashSet, ops::Not};
|
||||
use std::ops::Not;
|
||||
|
||||
// TODO this should probably rather be a finite field element.
|
||||
use crate::number::{format_number, is_zero, AbstractNumberType, GOLDILOCKS_MOD};
|
||||
|
||||
use super::bit_constraints::{BitConstraint, BitConstraintSet};
|
||||
use super::bit_constraints::BitConstraintSet;
|
||||
use super::eval_error::EvalError::ConflictingBitConstraints;
|
||||
use super::util::WitnessColumnNamer;
|
||||
use super::Constraint;
|
||||
@@ -186,7 +186,7 @@ impl AffineExpression {
|
||||
|
||||
parts
|
||||
.reduce(|c1, c2| match (c1, c2) {
|
||||
(Some(c1), Some(c2)) => c1.try_combine(&c2),
|
||||
(Some(c1), Some(c2)) => c1.try_combine_sum(&c2),
|
||||
_ => None,
|
||||
})
|
||||
.flatten()
|
||||
@@ -206,6 +206,7 @@ impl AffineExpression {
|
||||
.map(|(i, coeff)| {
|
||||
(
|
||||
i,
|
||||
coeff,
|
||||
known_constraints
|
||||
.bit_constraint(i)
|
||||
.unwrap()
|
||||
@@ -213,25 +214,24 @@ impl AffineExpression {
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if parts.iter().any(|(_i, con)| con.is_none()) {
|
||||
if parts.iter().any(|(_i, _coeff, con)| con.is_none()) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
// Check if they are mutually exclusive and compute assignments.
|
||||
let mut covered_bits = HashSet::<u64>::new();
|
||||
let mut covered_bits: AbstractNumberType = 0.into();
|
||||
let mut assignments = vec![];
|
||||
let mut offset = clamp(-self.offset.clone());
|
||||
for (i, con) in parts {
|
||||
let con = con.clone().unwrap();
|
||||
let BitConstraint { min_bit, max_bit } = con;
|
||||
for bit in min_bit..=max_bit {
|
||||
if !covered_bits.insert(bit) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
for (i, coeff, constraint) in parts {
|
||||
let constraint = constraint.clone().unwrap();
|
||||
let mask = constraint.mask();
|
||||
if mask.clone() & covered_bits.clone() != 0.into() {
|
||||
return Ok(vec![]);
|
||||
} else {
|
||||
covered_bits |= mask.clone();
|
||||
}
|
||||
let mask: AbstractNumberType = con.mask();
|
||||
assignments.push((
|
||||
i,
|
||||
Constraint::Assignment((offset.clone() & mask.clone()) >> min_bit),
|
||||
Constraint::Assignment((offset.clone() & mask.clone()) / coeff.clone()),
|
||||
));
|
||||
offset &= mask.not();
|
||||
}
|
||||
@@ -429,28 +429,40 @@ mod test {
|
||||
- AffineExpression::from_witness_poly_value(3);
|
||||
let known_constraints = TestBitConstraints(
|
||||
vec![
|
||||
(2, BitConstraint::from_max(7)),
|
||||
(3, BitConstraint::from_max(3)),
|
||||
(2, BitConstraint::from_max_bit(7)),
|
||||
(3, BitConstraint::from_max_bit(3)),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
);
|
||||
assert_eq!(
|
||||
expr.solve_with_bit_constraints(&known_constraints).unwrap(),
|
||||
vec![(1, Constraint::BitConstraint(BitConstraint::from_max(11)))]
|
||||
vec![(
|
||||
1,
|
||||
Constraint::BitConstraint(BitConstraint::from_max_bit(11))
|
||||
)]
|
||||
);
|
||||
assert_eq!(
|
||||
(-expr)
|
||||
.solve_with_bit_constraints(&known_constraints)
|
||||
.unwrap(),
|
||||
vec![(1, Constraint::BitConstraint(BitConstraint::from_max(11)))]
|
||||
vec![(
|
||||
1,
|
||||
Constraint::BitConstraint(BitConstraint::from_max_bit(11))
|
||||
)]
|
||||
);
|
||||
|
||||
// Replace factor 16 by 32.
|
||||
let expr = AffineExpression::from_witness_poly_value(1)
|
||||
- AffineExpression::from_witness_poly_value(2).mul(32.into())
|
||||
- AffineExpression::from_witness_poly_value(3);
|
||||
assert!(expr.solve_with_bit_constraints(&known_constraints).is_err());
|
||||
assert_eq!(
|
||||
expr.solve_with_bit_constraints(&known_constraints).unwrap(),
|
||||
vec![(
|
||||
1,
|
||||
Constraint::BitConstraint(BitConstraint::from_mask(0x1fef.into()))
|
||||
)]
|
||||
);
|
||||
|
||||
// Replace factor 16 by 8.
|
||||
let expr = AffineExpression::from_witness_poly_value(1)
|
||||
@@ -467,8 +479,8 @@ mod test {
|
||||
- AffineExpression::from_witness_poly_value(3);
|
||||
let known_constraints = TestBitConstraints(
|
||||
vec![
|
||||
(2, BitConstraint::from_max(7)),
|
||||
(3, BitConstraint::from_max(3)),
|
||||
(2, BitConstraint::from_max_bit(7)),
|
||||
(3, BitConstraint::from_max_bit(3)),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
@@ -491,8 +503,8 @@ mod test {
|
||||
- AffineExpression::from_witness_poly_value(3);
|
||||
let known_constraints = TestBitConstraints(
|
||||
vec![
|
||||
(2, BitConstraint::from_max(7)),
|
||||
(3, BitConstraint::from_max(3)),
|
||||
(2, BitConstraint::from_max_bit(7)),
|
||||
(3, BitConstraint::from_max_bit(3)),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
||||
@@ -13,49 +13,52 @@ use super::{Constraint, FixedData};
|
||||
/// All bits smaller than min_bit have to be zero
|
||||
/// and all bits larger than max_bit have to be zero.
|
||||
/// The least significant bit is bit zero.
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
#[derive(PartialEq, Clone)]
|
||||
pub struct BitConstraint {
|
||||
pub min_bit: u64,
|
||||
pub max_bit: u64,
|
||||
mask: AbstractNumberType,
|
||||
}
|
||||
|
||||
impl BitConstraint {
|
||||
pub fn from_max(max_bit: u64) -> Self {
|
||||
pub fn from_max_bit(max_bit: u64) -> Self {
|
||||
assert!(max_bit < 1024);
|
||||
BitConstraint {
|
||||
min_bit: 0,
|
||||
max_bit,
|
||||
mask: (AbstractNumberType::from(1) << (max_bit + 1)) - AbstractNumberType::from(1),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_mask(mask: AbstractNumberType) -> Self {
|
||||
BitConstraint { mask }
|
||||
}
|
||||
|
||||
/// The bit constraint of the sum of two expressions.
|
||||
pub fn try_combine(&self, other: &BitConstraint) -> Option<BitConstraint> {
|
||||
if self.max_bit + 1 == other.min_bit {
|
||||
pub fn try_combine_sum(&self, other: &BitConstraint) -> Option<BitConstraint> {
|
||||
if self.mask.clone() & other.mask.clone() == 0.into() {
|
||||
Some(BitConstraint {
|
||||
min_bit: self.min_bit,
|
||||
max_bit: other.max_bit,
|
||||
})
|
||||
} else if other.max_bit + 1 == self.min_bit {
|
||||
Some(BitConstraint {
|
||||
min_bit: other.min_bit,
|
||||
max_bit: self.max_bit,
|
||||
mask: self.mask.clone() | other.mask.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the conjunction of this constraint and the other.
|
||||
pub fn conjunction(self, other: &BitConstraint) -> BitConstraint {
|
||||
BitConstraint {
|
||||
mask: self.mask & other.mask.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// The bit constraint of an integer multiple of an expression.
|
||||
/// TODO this assumes goldilocks
|
||||
pub fn multiple(&self, factor: AbstractNumberType) -> Option<BitConstraint> {
|
||||
if factor.clone() * (1u64 << self.max_bit) >= GOLDILOCKS_MOD.into() {
|
||||
if factor.clone() * self.mask.clone() >= GOLDILOCKS_MOD.into() {
|
||||
None
|
||||
} else {
|
||||
// TODO use binary logarithm
|
||||
(0..64).find_map(|i| {
|
||||
if factor.clone() == (1u64 << i).into() {
|
||||
Some(BitConstraint {
|
||||
min_bit: self.min_bit + i,
|
||||
max_bit: self.max_bit + i,
|
||||
mask: self.mask.clone() << i,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
@@ -65,7 +68,7 @@ impl BitConstraint {
|
||||
}
|
||||
|
||||
pub fn mask(&self) -> AbstractNumberType {
|
||||
((AbstractNumberType::from(1) << (1 + self.max_bit - self.min_bit)) - 1) << self.min_bit
|
||||
self.mask.clone()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +78,14 @@ impl Display for BitConstraint {
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for BitConstraint {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("BitConstraint")
|
||||
.field("mask", &format!("0x{:x}", &self.mask))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that provides a bit constraint on a symbolic variable if given by ID.
|
||||
pub trait BitConstraintSet {
|
||||
fn bit_constraint(&self, id: usize) -> Option<BitConstraint>;
|
||||
@@ -102,21 +113,24 @@ pub fn determine_global_constraints<'a>(
|
||||
identities: Vec<&'a Identity>,
|
||||
) -> (BTreeMap<&'a str, BitConstraint>, Vec<&'a Identity>) {
|
||||
let mut known_constraints = BTreeMap::new();
|
||||
// For these columns, we know that they are not only constrained to those bits
|
||||
// but also have one row for each possible value.
|
||||
// It allows us to completely remove some lookups.
|
||||
let mut full_span = BTreeSet::new();
|
||||
for (&name, &values) in &fixed_data.fixed_cols {
|
||||
if let Some(cons) = process_fixed_column(values) {
|
||||
known_constraints.insert(name, cons);
|
||||
println!("constr for {name}");
|
||||
if let Some((cons, full)) = process_fixed_column(values) {
|
||||
assert!(known_constraints.insert(name, cons).is_none());
|
||||
if full {
|
||||
full_span.insert(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For these columns, we know that they are not only constrained to those bits
|
||||
// but also have one row for each possible value.
|
||||
let full_span = known_constraints.keys().copied().collect::<BTreeSet<_>>();
|
||||
|
||||
if fixed_data.verbose {
|
||||
println!("Determined the following bit constraints on fixed columns:");
|
||||
for (name, con) in &known_constraints {
|
||||
println!(" {name}: {con}");
|
||||
}
|
||||
//if fixed_data.verbose {
|
||||
println!("Determined the following bit constraints on fixed columns:");
|
||||
for (name, con) in &known_constraints {
|
||||
println!(" {name}: {con}");
|
||||
}
|
||||
|
||||
let mut retained_identities = vec![];
|
||||
@@ -146,21 +160,27 @@ pub fn determine_global_constraints<'a>(
|
||||
/// Analyzes a fixed column and checks if its values correspond exactly
|
||||
/// to a certain bit pattern.
|
||||
/// TODO do this on the symbolic definition instead of the values.
|
||||
fn process_fixed_column(fixed: &[AbstractNumberType]) -> Option<BitConstraint> {
|
||||
fn process_fixed_column(fixed: &[AbstractNumberType]) -> Option<(BitConstraint, bool)> {
|
||||
if let Some(bit) = smallest_period_candidate(fixed) {
|
||||
let mask: u64 = (1u64 << bit) - 1;
|
||||
for (i, v) in fixed.iter().enumerate() {
|
||||
if *v != (i as u64 & mask).into() {
|
||||
return None;
|
||||
}
|
||||
let mask: AbstractNumberType =
|
||||
(AbstractNumberType::from(1) << bit) - AbstractNumberType::from(1);
|
||||
if fixed
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, v)| *v == AbstractNumberType::from(i) & mask.clone())
|
||||
{
|
||||
return Some((BitConstraint::from_mask(mask), true));
|
||||
}
|
||||
Some(BitConstraint {
|
||||
min_bit: 0,
|
||||
max_bit: bit - 1,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
let mut mask = 0.into();
|
||||
for v in fixed.iter() {
|
||||
if *v < 0.into() {
|
||||
return None;
|
||||
}
|
||||
mask |= v.clone();
|
||||
}
|
||||
|
||||
Some((BitConstraint::from_mask(mask), false))
|
||||
}
|
||||
|
||||
/// Deduces new bit constraints on witness columns from constraints on fixed columns
|
||||
@@ -180,7 +200,7 @@ fn propagate_constraints<'a>(
|
||||
is_binary_constraint(fixed_data, identity.left.selector.as_ref().unwrap())
|
||||
{
|
||||
assert!(known_constraints
|
||||
.insert(p, BitConstraint::from_max(0))
|
||||
.insert(p, BitConstraint::from_max_bit(0))
|
||||
.is_none());
|
||||
remove = true;
|
||||
} else if let Some((p, c)) = try_transfer_constraints(
|
||||
@@ -188,7 +208,10 @@ fn propagate_constraints<'a>(
|
||||
identity.left.selector.as_ref().unwrap(),
|
||||
&known_constraints,
|
||||
) {
|
||||
assert!(known_constraints.insert(p, c).is_none());
|
||||
known_constraints
|
||||
.entry(p)
|
||||
.and_modify(|existing| *existing = existing.clone().conjunction(&c))
|
||||
.or_insert(c);
|
||||
}
|
||||
}
|
||||
IdentityKind::Plookup | IdentityKind::Permutation | IdentityKind::Connect => {
|
||||
@@ -203,7 +226,12 @@ fn propagate_constraints<'a>(
|
||||
{
|
||||
if let (Some(left), Some(right)) = (is_simple_poly(left), is_simple_poly(right)) {
|
||||
if let Some(constraint) = known_constraints.get(right).cloned() {
|
||||
assert!(known_constraints.insert(left, constraint).is_none());
|
||||
known_constraints
|
||||
.entry(left)
|
||||
.and_modify(|existing| {
|
||||
*existing = existing.clone().conjunction(&constraint)
|
||||
})
|
||||
.or_insert(constraint);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -324,7 +352,10 @@ mod test {
|
||||
#[test]
|
||||
fn all_zeros() {
|
||||
let fixed = [0, 0, 0, 0].iter().map(|v| (*v).into()).collect::<Vec<_>>();
|
||||
assert_eq!(process_fixed_column(&fixed), None);
|
||||
assert_eq!(
|
||||
process_fixed_column(&fixed),
|
||||
Some((BitConstraint::from_mask(0.into()), false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -335,10 +366,7 @@ mod test {
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
process_fixed_column(&fixed),
|
||||
Some(BitConstraint {
|
||||
min_bit: 0,
|
||||
max_bit: 0
|
||||
})
|
||||
Some((BitConstraint::from_mask(1.into()), true))
|
||||
);
|
||||
}
|
||||
|
||||
@@ -350,10 +378,19 @@ mod test {
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
process_fixed_column(&fixed),
|
||||
Some(BitConstraint {
|
||||
min_bit: 0,
|
||||
max_bit: 1
|
||||
})
|
||||
Some((BitConstraint::from_mask(3.into()), true))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn various_with_bit_mask() {
|
||||
let fixed = [0, 6, 0x0100, 0x1100, 2]
|
||||
.iter()
|
||||
.map(|v| (*v).into())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
process_fixed_column(&fixed),
|
||||
Some((BitConstraint::from_mask(0x1106.into()), false))
|
||||
);
|
||||
}
|
||||
|
||||
@@ -363,27 +400,32 @@ mod test {
|
||||
namespace Global(2**20);
|
||||
col fixed BYTE(i) { i & 0xff };
|
||||
col fixed BYTE2(i) { i & 0xffff };
|
||||
col fixed SHIFTED(i) { i & 0xff0 };
|
||||
col witness A;
|
||||
// A bit more complicated to see that the 'pattern matcher' works properly.
|
||||
(1 - A + 0) * (A + 1 - 1) = 0;
|
||||
col witness B;
|
||||
{ B } in { BYTE };
|
||||
col witness C;
|
||||
C = A * 2**8 + B;
|
||||
C = A * 512 + B;
|
||||
col witness D;
|
||||
{ D } in { BYTE };
|
||||
{ D } in { SHIFTED };
|
||||
";
|
||||
let analyzed = crate::analyzer::analyze_string(pil_source);
|
||||
let (constants, degree) = crate::constant_evaluator::generate(&analyzed);
|
||||
let mut known_constraints = constants
|
||||
.iter()
|
||||
.filter_map(|(name, values)| {
|
||||
process_fixed_column(values).map(|constraint| (*name, constraint))
|
||||
process_fixed_column(values).map(|(constraint, _full)| (*name, constraint))
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
assert_eq!(
|
||||
known_constraints,
|
||||
vec![
|
||||
("Global.BYTE", BitConstraint::from_max(7)),
|
||||
("Global.BYTE2", BitConstraint::from_max(15)),
|
||||
("Global.BYTE", BitConstraint::from_max_bit(7)),
|
||||
("Global.BYTE2", BitConstraint::from_max_bit(15)),
|
||||
("Global.SHIFTED", BitConstraint::from_mask(0xff0.into())),
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
@@ -419,11 +461,13 @@ namespace Global(2**20);
|
||||
assert_eq!(
|
||||
known_constraints,
|
||||
vec![
|
||||
("Global.A", BitConstraint::from_max(0)),
|
||||
("Global.B", BitConstraint::from_max(7)),
|
||||
("Global.C", BitConstraint::from_max(8)),
|
||||
("Global.BYTE", BitConstraint::from_max(7)),
|
||||
("Global.BYTE2", BitConstraint::from_max(15)),
|
||||
("Global.A", BitConstraint::from_max_bit(0)),
|
||||
("Global.B", BitConstraint::from_max_bit(7)),
|
||||
("Global.C", BitConstraint::from_mask(0x2ff.into())),
|
||||
("Global.D", BitConstraint::from_mask(0xf0.into())),
|
||||
("Global.BYTE", BitConstraint::from_max_bit(7)),
|
||||
("Global.BYTE2", BitConstraint::from_max_bit(15)),
|
||||
("Global.SHIFTED", BitConstraint::from_mask(0xff0.into())),
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
@@ -432,21 +476,24 @@ namespace Global(2**20);
|
||||
|
||||
#[test]
|
||||
fn combinations() {
|
||||
let a = BitConstraint::from_max(7);
|
||||
let a = BitConstraint::from_max_bit(7);
|
||||
assert_eq!(a, BitConstraint::from_mask(0xff.into()));
|
||||
let b = a.multiple(256.into()).unwrap();
|
||||
assert_eq!(b, BitConstraint::from_mask(0xff00.into()));
|
||||
assert_eq!(
|
||||
b,
|
||||
BitConstraint {
|
||||
min_bit: 8,
|
||||
max_bit: 15
|
||||
}
|
||||
b.try_combine_sum(&a).unwrap(),
|
||||
BitConstraint::from_mask(0xffff.into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn weird_combinations() {
|
||||
let a = BitConstraint::from_mask(0xf00f.into());
|
||||
let b = a.multiple(256.into()).unwrap();
|
||||
assert_eq!(b, BitConstraint::from_mask(0xf00f00.into()));
|
||||
assert_eq!(
|
||||
b.try_combine(&a).unwrap(),
|
||||
BitConstraint {
|
||||
min_bit: 0,
|
||||
max_bit: 15
|
||||
}
|
||||
b.try_combine_sum(&a).unwrap(),
|
||||
BitConstraint::from_mask(0xf0ff0f.into())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user