Take other bus interactions into account when computing xor range constraints (#3012)

This commit is contained in:
chriseth
2025-07-22 11:32:46 +02:00
committed by GitHub
parent d74aaa1a83
commit 9a97fca7c6
6 changed files with 60 additions and 54 deletions

View File

@@ -1,11 +1,15 @@
use std::collections::HashSet;
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt::Debug, fmt::Display};
use itertools::Itertools;
use num_traits::{One, Zero};
use powdr_constraint_solver::constraint_system::{BusInteraction, ConstraintSystem};
use powdr_constraint_solver::constraint_system::{
BusInteraction, BusInteractionHandler, ConstraintSystem,
};
use powdr_constraint_solver::grouped_expression::GroupedExpression;
use powdr_constraint_solver::range_constraint::RangeConstraint;
use powdr_constraint_solver::solver::{self, bus_interaction_variable_wrapper, Solver};
use powdr_number::FieldElement;
/// Optimize interactions with the bitwise lookup bus. It mostly optimizes the use of
@@ -13,6 +17,7 @@ use powdr_number::FieldElement;
pub fn optimize_bitwise_lookup<T: FieldElement, V: Hash + Eq + Clone + Ord + Debug + Display>(
mut system: ConstraintSystem<T, V>,
bitwise_lookup_bus_id: u64,
bus_interaction_handler: impl BusInteractionHandler<T>,
) -> ConstraintSystem<T, V> {
// Expressions that we need to byte-constrain at the end.
let mut to_byte_constrain = vec![];
@@ -50,7 +55,7 @@ pub fn optimize_bitwise_lookup<T: FieldElement, V: Hash + Eq + Clone + Ord + Deb
let mut args = vec![x, y, z];
if let Some(zero_pos) = args.iter().position(|e| e.is_zero()) {
args.remove(zero_pos);
// The two remaning expressions in args are equal and bytes.
// The two remaining expressions in args are equal and bytes.
let [a, b] = args.try_into().unwrap();
new_constraints.push(a.clone() - b.clone());
to_byte_constrain.push(a.clone());
@@ -66,9 +71,10 @@ pub fn optimize_bitwise_lookup<T: FieldElement, V: Hash + Eq + Clone + Ord + Deb
// After we have removed the bus interactions, we check which of the
// expressions we still need to byte-constrain. Some are maybe already
// byte-constrained by other bus interactions.
let already_byte_constrained = all_byte_constrained_expressions(&system, bitwise_lookup_bus_id)
.cloned()
.collect::<HashSet<_>>();
let byte_range_constraint = RangeConstraint::from_mask(0xffu64);
let range_constraints =
determine_range_constraints_using_solver(&system, bus_interaction_handler);
let mut to_byte_constrain = to_byte_constrain
.into_iter()
.filter(|expr| {
@@ -76,8 +82,10 @@ pub fn optimize_bitwise_lookup<T: FieldElement, V: Hash + Eq + Clone + Ord + Deb
assert!(n >= T::from(0) && n < T::from(256));
// No need to byte-constrain numbers.
false
} else if let Some(rc) = range_constraints.get(expr) {
*rc != rc.conjunction(&byte_range_constraint)
} else {
!already_byte_constrained.contains(expr)
true
}
})
.unique()
@@ -104,32 +112,31 @@ fn is_simple_multiplicity_bitwise_bus_interaction<T: FieldElement, V: Clone + Ha
&& bus_int.multiplicity.is_one()
}
/// Returns all expressions that are byte-constrained in the machine.
/// The list does not have to be exhaustive.
fn all_byte_constrained_expressions<T: FieldElement, V: Clone + Ord + Hash>(
machine: &ConstraintSystem<T, V>,
bitwise_lookup_bus_id: u64,
) -> impl Iterator<Item = &GroupedExpression<T, V>> {
machine
.bus_interactions
.iter()
.filter(move |bus_int| {
is_simple_multiplicity_bitwise_bus_interaction(bus_int, bitwise_lookup_bus_id)
})
.flat_map(|bus_int| {
let [x, y, z, op] = &bus_int.payload[..] else {
panic!();
};
if let Some(op) = op.try_to_number() {
if op == T::from(0) {
vec![x, y]
} else if op == T::from(1) {
vec![x, y, z]
} else {
vec![]
fn determine_range_constraints_using_solver<
T: FieldElement,
V: Clone + Hash + Eq + Ord + Debug + Display,
>(
system: &ConstraintSystem<T, V>,
bus_interaction_handler: impl BusInteractionHandler<T>,
) -> HashMap<GroupedExpression<T, V>, RangeConstraint<T>> {
let (wrapper, transformed_system) = solver::bus_interaction_variable_wrapper::BusInteractionVariableWrapper::replace_bus_interaction_expressions(system.clone());
Solver::new(transformed_system)
.with_bus_interaction_handler(bus_interaction_handler)
.solve()
.unwrap()
.range_constraints
.range_constraints
.into_iter()
.map(|(var, range_constraint)| {
let expr = match var {
bus_interaction_variable_wrapper::Variable::BusInteractionField(..) => {
wrapper.bus_interaction_vars[&var].clone()
}
} else {
vec![]
}
bus_interaction_variable_wrapper::Variable::Variable(v) => {
GroupedExpression::from_unknown_variable(v)
}
};
(expr, range_constraint)
})
.collect()
}

View File

@@ -92,7 +92,11 @@ fn optimization_loop_iteration<A: Adapter>(
};
let system = if let Some(bitwise_bus_id) = bus_map.get_bus_id(&BusType::BitwiseLookup) {
let system = optimize_bitwise_lookup(constraint_system, bitwise_bus_id);
let system = optimize_bitwise_lookup(
constraint_system,
bitwise_bus_id,
bus_interaction_handler.clone(),
);
stats_logger.log("optimizing bitwise lookup", &system);
system
} else {

View File

@@ -18,7 +18,7 @@ use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Display};
use std::hash::Hash;
mod bus_interaction_variable_wrapper;
pub mod bus_interaction_variable_wrapper;
mod exhaustive_search;
mod quadratic_equivalences;

View File

@@ -1459,10 +1459,10 @@ mod tests {
widths: AirWidths {
preprocessed: 0,
main: 14676,
log_up: 12128,
log_up: 12112,
},
constraints: 4143,
bus_interactions: 11673,
bus_interactions: 11668,
},
powdr_expected_machine_count: 10,
non_powdr_expected_sum: NON_POWDR_EXPECTED_SUM,
@@ -1483,10 +1483,10 @@ mod tests {
widths: AirWidths {
preprocessed: 0,
main: 14656,
log_up: 12108,
log_up: 12092,
},
constraints: 4127,
bus_interactions: 11663,
bus_interactions: 11658,
},
powdr_expected_machine_count: 10,
non_powdr_expected_sum: NON_POWDR_EXPECTED_SUM,
@@ -1508,7 +1508,7 @@ mod tests {
after: AirWidths {
preprocessed: 0,
main: 14656,
log_up: 12108,
log_up: 12092,
},
}),
});
@@ -1617,11 +1617,11 @@ mod tests {
let powdr_metrics_sum = AirMetrics {
widths: AirWidths {
preprocessed: 0,
main: 4831,
log_up: 3968,
main: 4843,
log_up: 3952,
},
constraints: 958,
bus_interactions: 3821,
constraints: 962,
bus_interactions: 3818,
};
let expected_metrics = MachineTestMetrics {
@@ -1640,13 +1640,13 @@ mod tests {
expected_columns_saved: Some(AirWidthsDiff {
before: AirWidths {
preprocessed: 0,
main: 38950,
log_up: 26908,
main: 38986,
log_up: 26936,
},
after: AirWidths {
preprocessed: 0,
main: 4831,
log_up: 3968,
main: 4843,
log_up: 3952,
},
}),
});

View File

@@ -29,8 +29,7 @@ mult=is_valid * 1, args=[b__0_0, 3, b__0_0 + 3 - 2 * b__0_1, 1]
mult=diff_marker__0_1 + diff_marker__1_1 + diff_marker__2_1 + diff_marker__3_1, args=[diff_val_1 - 1, 0, 0, 0]
mult=diff_marker__0_2 + diff_marker__1_2 + diff_marker__2_2 + diff_marker__3_2, args=[diff_val_2 - 1, 0, 0, 0]
mult=is_valid * 1, args=[b__0_3, c__0_3, 2 * a__0_4 - (b__0_3 + c__0_3), 1]
mult=is_valid * 1, args=[b__1_0, b__2_0, 0, 0]
mult=is_valid * 1, args=[b__3_0, b_msb_f_2, 0, 0]
mult=is_valid * 1, args=[b_msb_f_2, 0, 0, 0]
// Algebraic constraints:
b__0_3 * (b__0_3 - 1) = 0

View File

@@ -16,9 +16,5 @@ mult=is_valid * 1, args=[reads_aux__0__base__timestamp_lt_aux__lower_decomp__1_0
mult=is_valid * 1, args=[writes_aux__base__timestamp_lt_aux__lower_decomp__0_0, 17]
mult=is_valid * 1, args=[writes_aux__base__timestamp_lt_aux__lower_decomp__1_0, 12]
// Bus 6 (BITWISE_LOOKUP):
mult=is_valid * 1, args=[b__0_0, b__1_0, 0, 0]
mult=is_valid * 1, args=[b__2_0, 0, 0, 0]
// Algebraic constraints:
is_valid * (is_valid - 1) = 0