diff --git a/executor/src/witgen/generator.rs b/executor/src/witgen/generator.rs index 0c7650efe..0eefb9560 100644 --- a/executor/src/witgen/generator.rs +++ b/executor/src/witgen/generator.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use number::{DegreeType, FieldElement}; use parser_util::lines::indent; use pil_analyzer::{Expression, Identity, IdentityKind, PolynomialReference}; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::time::Instant; use super::affine_expression::{AffineExpression, AffineResult}; @@ -17,6 +17,7 @@ pub struct Generator<'a, T: FieldElement, QueryCallback: Send + Sync> { fixed_data: &'a FixedData<'a, T>, fixed_lookup: &'a mut FixedLookup, identities: &'a [&'a Identity], + witnesses: BTreeSet<&'a PolynomialReference>, machines: Vec>>, query_callback: Option, global_range_constraints: BTreeMap<&'a PolynomialReference, RangeConstraint>, @@ -64,6 +65,7 @@ where fixed_data: &'a FixedData<'a, T>, fixed_lookup: &'a mut FixedLookup, identities: &'a [&'a Identity], + witnesses: BTreeSet<&'a PolynomialReference>, global_range_constraints: BTreeMap<&'a PolynomialReference, RangeConstraint>, machines: Vec>>, query_callback: Option, @@ -74,6 +76,7 @@ where fixed_data, fixed_lookup, identities, + witnesses, machines, query_callback, global_range_constraints, @@ -156,10 +159,12 @@ where self.next .iter() .enumerate() - .filter_map(|(i, v)| if v.is_none() { - Some(self.fixed_data.witness_cols[i].poly.to_string()) - } else { - None + .filter_map(|(i, v)| { + if v.is_none() && self.is_relevant_witness(i) { + Some(self.fixed_data.witness_cols[i].poly.to_string()) + } else { + None + } }) .collect::>() .join(", ") @@ -246,7 +251,12 @@ where } fn format_next_values(&self) -> Vec { - self.format_next_values_iter(self.next.iter().enumerate()) + self.format_next_values_iter( + self.next + .iter() + .enumerate() + .filter(|(i, _)| self.is_relevant_witness(*i)), + ) } fn format_next_known_values(&self) -> Vec { @@ -496,6 +506,12 @@ where self.next[id].is_some() } + /// Returns true if this is a witness column we care about (instead of a sub-machine witness). + fn is_relevant_witness(&self, id: usize) -> bool { + self.witnesses + .contains(&self.fixed_data.witness_cols[id].poly) + } + /// Tries to evaluate the expression to an expression affine in the witness polynomials, /// taking current values of polynomials into account. /// @returns an expression affine in the witness polynomials diff --git a/executor/src/witgen/machines/machine_extractor.rs b/executor/src/witgen/machines/machine_extractor.rs index 40d1629ca..b1563b1b1 100644 --- a/executor/src/witgen/machines/machine_extractor.rs +++ b/executor/src/witgen/machines/machine_extractor.rs @@ -19,6 +19,7 @@ pub struct ExtractionOutput<'a, T> { pub fixed_lookup: FixedLookup, pub machines: Vec>>, pub base_identities: Vec<&'a Identity>, + pub base_witnesses: HashSet<&'a PolynomialReference>, } /// Finds machines in the witness columns and identities @@ -35,7 +36,7 @@ pub fn split_out_machines<'a, T: FieldElement>( let mut machines: Vec>> = vec![]; let all_witnesses = witness_cols.iter().map(|c| &c.poly).collect::>(); - let mut remaining_witnesses = all_witnesses.clone(); + let mut remaining_witnesses: HashSet<&'a PolynomialReference> = all_witnesses.clone(); let mut base_identities = identities.clone(); for id in &identities { // Extract all witness columns in the RHS of the lookup. @@ -124,6 +125,7 @@ pub fn split_out_machines<'a, T: FieldElement>( fixed_lookup: *fixed_lookup, machines, base_identities, + base_witnesses: remaining_witnesses, } } @@ -133,7 +135,7 @@ pub fn split_out_machines<'a, T: FieldElement>( fn all_row_connected_witnesses<'a, T>( mut witnesses: HashSet<&'a PolynomialReference>, all_witnesses: &HashSet<&'a PolynomialReference>, - identities: &'a [&'a Identity], + identities: &[&'a Identity], ) -> HashSet<&'a PolynomialReference> { loop { let count = witnesses.len(); diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index ac012b997..a8b338390 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -71,6 +71,7 @@ where mut fixed_lookup, machines, base_identities, + base_witnesses, } = machines::machine_extractor::split_out_machines( &fixed, retained_identities, @@ -81,6 +82,7 @@ where &fixed, &mut fixed_lookup, &base_identities, + base_witnesses.into_iter().collect(), known_constraints, machines, query_callback,