From 6ca7cffc3879af6242e6afd5b6f5c525958fd0ab Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 4 Apr 2023 16:16:09 +0200 Subject: [PATCH] Split out multiple machines. --- .../machines/machine_extractor.rs | 225 ++++++++++-------- 1 file changed, 122 insertions(+), 103 deletions(-) diff --git a/src/witness_generator/machines/machine_extractor.rs b/src/witness_generator/machines/machine_extractor.rs index 2cedf47af..1cbd700e1 100644 --- a/src/witness_generator/machines/machine_extractor.rs +++ b/src/witness_generator/machines/machine_extractor.rs @@ -1,12 +1,11 @@ use std::collections::HashSet; -use crate::analyzer::{Expression, Identity, SelectedExpressions}; - use super::double_sorted_witness_machine::DoubleSortedWitnesses; use super::fixed_lookup_machine::FixedLookup; use super::sorted_witness_machine::SortedWitnesses; use super::FixedData; use super::Machine; +use crate::analyzer::{Expression, Identity, SelectedExpressions}; use crate::witness_generator::WitnessColumn; /// Finds machines in the witness columns and identities @@ -17,67 +16,99 @@ pub fn split_out_machines<'a>( identities: &'a [Identity], witness_cols: &'a [WitnessColumn], ) -> (Vec>, Vec<&'a Identity>) { - // TODO we only split out one machine for now. - // We could also split the machine into independent sub-machines. - // The lookup-in-fixed-columns machine, it always exists with an empty set of witnesses. let mut machines: Vec> = vec![FixedLookup::try_new(fixed, &[], &Default::default()).unwrap()]; - let witness_names = witness_cols.iter().map(|c| c.name).collect::>(); - let all_witnesses = ReferenceExtractor::new(witness_names.clone()); - // Extract all witness columns in the RHS of lookups. - let lookup_witnesses = identities - .iter() - .map(|i| all_witnesses.in_selected_expressions(&i.right)) - .reduce(|l, r| &l | &r) - .unwrap_or_default(); - - // Recursively extend the set to all witnesses connected through identities. - let machine_witnesses = all_connected_witnesses(&all_witnesses, lookup_witnesses, identities); - let machine_witness_extractor = ReferenceExtractor::new(machine_witnesses.clone()); - - // Split identities into those that only concern the machine - // witnesses and those that concern any other witness. - let (machine_identities, base_identities): (Vec<_>, _) = identities.iter().partition(|i| { - // The identity has at least one machine witness, but - // all referenced witnesses are machine witnesses. - let mw = machine_witness_extractor.in_identity(i); - !mw.is_empty() && all_witnesses.in_identity(i).is_subset(&mw) - }); - - // TODO we probably need to check that machine witnesses do not appear - // in any identity among `identities` except on the RHS. - - if let Some(machine) = SortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses) - { - if fixed.verbose { - println!("Detected machine: sorted witnesses / write-once memory"); + let all_witnesses = witness_cols.iter().map(|c| c.name).collect::>(); + let mut remaining_witnesses = all_witnesses.clone(); + let mut base_identities = identities.iter().collect::>(); + for id in identities { + // Extract all witness columns in the RHS of the lookup. + let lookup_witnesses = &refs_in_selected_expressions(&id.right) & (&remaining_witnesses); + if lookup_witnesses.is_empty() { + continue; } - machines.push(machine); - } else if let Some(machine) = - DoubleSortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses) - { + + // Recursively extend the set to all witnesses connected through identities that preserve + // a fixed row relation. + let machine_witnesses = + all_row_connected_witnesses(lookup_witnesses, &remaining_witnesses, identities); + + // Split identities into those that only concern the machine + // witnesses and those that concern any other witness. + let (machine_identities, remaining_identities): (Vec<_>, _) = + base_identities.iter().partition(|i| { + // The identity has at least one machine witness, but + // all referenced witnesses are machine witnesses. + let all_refs = &refs_in_identity(i) & (&all_witnesses); + !all_refs.is_empty() && all_refs.is_subset(&machine_witnesses) + }); + base_identities = remaining_identities; + remaining_witnesses = &remaining_witnesses - &machine_witnesses; + if fixed.verbose { - println!("Detected machine: memory"); + println!( + "Extracted a machine with the following witnesses and identities:\n{}\n{}", + machine_witnesses + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(", "), + machine_identities + .iter() + .map(|id| id.to_string()) + .collect::>() + .join("\n") + ); + } + + if let Some(machine) = + SortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses) + { + if fixed.verbose { + println!("Detected machine: sorted witnesses / write-once memory"); + } + machines.push(machine); + } else if let Some(machine) = + DoubleSortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses) + { + if fixed.verbose { + println!("Detected machine: memory"); + } + machines.push(machine); + } else { + println!( + "Could not find a matching machine to handle a query to the following witness set:\n{}", + machine_witnesses + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(", ") + ); + remaining_witnesses = &remaining_witnesses | &machine_witnesses; + base_identities.extend(machine_identities); + println!("Will try to continue as is, but this probably requires a specialized machine implementation."); } - machines.push(machine); } (machines, base_identities) } -fn all_connected_witnesses<'a>( - all_witnesses: &'a ReferenceExtractor, +/// Extends a set of witnesses to the full set of row-connected witnesses. +/// Two witnesses are row-connected if they are part of a polynomial identity +/// or part of the same side of a lookup. +fn all_row_connected_witnesses<'a>( mut witnesses: HashSet<&'a str>, + all_witnesses: &HashSet<&'a str>, identities: &'a [Identity], ) -> HashSet<&'a str> { - let mut count = witnesses.len(); loop { + let count = witnesses.len(); for i in identities { match i.kind { crate::analyzer::IdentityKind::Polynomial => { // Any current witness in the identity adds all other witnesses. - let in_identity = all_witnesses.in_identity(i); + let in_identity = &refs_in_identity(i) & all_witnesses; if in_identity.intersection(&witnesses).next().is_some() { witnesses.extend(in_identity); } @@ -85,12 +116,12 @@ fn all_connected_witnesses<'a>( crate::analyzer::IdentityKind::Plookup | crate::analyzer::IdentityKind::Permutation | crate::analyzer::IdentityKind::Connect => { - // If we already have witnesses on the LHS, include the RHS, but not vice-versa. - let in_lhs = all_witnesses.in_selected_expressions(&i.left); - let in_rhs = all_witnesses.in_selected_expressions(&i.right); + // If we already have witnesses on the LHS, include the LHS, + // and vice-versa, but not across the "sides". + let in_lhs = &refs_in_selected_expressions(&i.left) & all_witnesses; + let in_rhs = &refs_in_selected_expressions(&i.right) & all_witnesses; if in_lhs.intersection(&witnesses).next().is_some() { witnesses.extend(in_lhs); - witnesses.extend(in_rhs); } else if in_rhs.intersection(&witnesses).next().is_some() { witnesses.extend(in_rhs); } @@ -100,66 +131,54 @@ fn all_connected_witnesses<'a>( if witnesses.len() == count { return witnesses; } - count = witnesses.len() } } -/// Extracts all references to any of the given names -/// in expressions and identities. -struct ReferenceExtractor<'a> { - names: HashSet<&'a str>, +/// Extracts all references to names from an identity. +pub fn refs_in_identity(identity: &Identity) -> HashSet<&str> { + &refs_in_selected_expressions(&identity.left) | &refs_in_selected_expressions(&identity.right) } -impl<'a> ReferenceExtractor<'a> { - pub fn new(names: HashSet<&'a str>) -> Self { - ReferenceExtractor { names } - } - pub fn in_identity(&self, identity: &'a Identity) -> HashSet<&'a str> { - &self.in_selected_expressions(&identity.left) - | &self.in_selected_expressions(&identity.right) - } - pub fn in_selected_expressions(&self, selexpr: &'a SelectedExpressions) -> HashSet<&'a str> { - selexpr - .expressions - .iter() - .chain(selexpr.selector.iter()) - .map(|e| self.in_expression(e)) - .reduce(|l, r| &l | &r) - .unwrap_or_default() - } - pub fn in_expression(&self, expr: &'a Expression) -> HashSet<&'a str> { - match expr { - Expression::Constant(_) => todo!(), - Expression::PolynomialReference(p) => { - if self.names.contains(p.name.as_str()) { - [p.name.as_str()].into() - } else { - HashSet::default() - } - } - Expression::Tuple(items) => self.in_expressions(items), - Expression::BinaryOperation(l, _, r) => &self.in_expression(l) | &self.in_expression(r), - Expression::UnaryOperation(_, e) => self.in_expression(e), - Expression::FunctionCall(_, args) => self.in_expressions(args), - Expression::MatchExpression(scrutinee, arms) => { - &self.in_expression(scrutinee) - | &arms - .iter() - .map(|(_, e)| self.in_expression(e)) - .reduce(|a, b| &a | &b) - .unwrap_or_default() - } - Expression::LocalVariableReference(_) - | Expression::PublicReference(_) - | Expression::Number(_) - | Expression::String(_) => HashSet::default(), +/// Extracts all references to names from selected expressions. +pub fn refs_in_selected_expressions(selexpr: &SelectedExpressions) -> HashSet<&str> { + selexpr + .expressions + .iter() + .chain(selexpr.selector.iter()) + .map(refs_in_expression) + .reduce(|l, r| &l | &r) + .unwrap_or_default() +} + +/// Extracts all references to names from an expression +pub fn refs_in_expression(expr: &Expression) -> HashSet<&str> { + match expr { + Expression::Constant(_) => todo!(), + Expression::PolynomialReference(p) => [p.name.as_str()].into(), + Expression::Tuple(items) => refs_in_expressions(items), + Expression::BinaryOperation(l, _, r) => &refs_in_expression(l) | &refs_in_expression(r), + Expression::UnaryOperation(_, e) => refs_in_expression(e), + Expression::FunctionCall(_, args) => refs_in_expressions(args), + Expression::MatchExpression(scrutinee, arms) => { + &refs_in_expression(scrutinee) + | &arms + .iter() + .map(|(_, e)| refs_in_expression(e)) + .reduce(|a, b| &a | &b) + .unwrap_or_default() } - } - pub fn in_expressions(&self, exprs: &'a [Expression]) -> HashSet<&'a str> { - exprs - .iter() - .map(|e| self.in_expression(e)) - .reduce(|l, r| &l | &r) - .unwrap_or_default() + Expression::LocalVariableReference(_) + | Expression::PublicReference(_) + | Expression::Number(_) + | Expression::String(_) => HashSet::default(), } } + +/// Extracts all references to names from expressions. +pub fn refs_in_expressions(exprs: &[Expression]) -> HashSet<&str> { + exprs + .iter() + .map(refs_in_expression) + .reduce(|l, r| &l | &r) + .unwrap_or_default() +}