Merge pull request #138 from chriseth/split_machines

Split out multiple machines.
This commit is contained in:
chriseth
2023-04-06 19:46:31 +02:00
committed by GitHub

View File

@@ -1,12 +1,11 @@
use std::collections::HashSet;
use crate::analyzer::{Expression, Identity, SelectedExpressions};
use super::double_sorted_witness_machine::DoubleSortedWitnesses;
use super::fixed_lookup_machine::FixedLookup;
use super::sorted_witness_machine::SortedWitnesses;
use super::FixedData;
use super::Machine;
use crate::analyzer::{Expression, Identity, SelectedExpressions};
use crate::witness_generator::WitnessColumn;
/// Finds machines in the witness columns and identities
@@ -17,67 +16,99 @@ pub fn split_out_machines<'a>(
identities: &'a [Identity],
witness_cols: &'a [WitnessColumn],
) -> (Vec<Box<dyn Machine>>, Vec<&'a Identity>) {
// TODO we only split out one machine for now.
// We could also split the machine into independent sub-machines.
// The lookup-in-fixed-columns machine, it always exists with an empty set of witnesses.
let mut machines: Vec<Box<dyn Machine>> =
vec![FixedLookup::try_new(fixed, &[], &Default::default()).unwrap()];
let witness_names = witness_cols.iter().map(|c| c.name).collect::<HashSet<_>>();
let all_witnesses = ReferenceExtractor::new(witness_names.clone());
// Extract all witness columns in the RHS of lookups.
let lookup_witnesses = identities
.iter()
.map(|i| all_witnesses.in_selected_expressions(&i.right))
.reduce(|l, r| &l | &r)
.unwrap_or_default();
// Recursively extend the set to all witnesses connected through identities.
let machine_witnesses = all_connected_witnesses(&all_witnesses, lookup_witnesses, identities);
let machine_witness_extractor = ReferenceExtractor::new(machine_witnesses.clone());
// Split identities into those that only concern the machine
// witnesses and those that concern any other witness.
let (machine_identities, base_identities): (Vec<_>, _) = identities.iter().partition(|i| {
// The identity has at least one machine witness, but
// all referenced witnesses are machine witnesses.
let mw = machine_witness_extractor.in_identity(i);
!mw.is_empty() && all_witnesses.in_identity(i).is_subset(&mw)
});
// TODO we probably need to check that machine witnesses do not appear
// in any identity among `identities` except on the RHS.
if let Some(machine) = SortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses)
{
if fixed.verbose {
println!("Detected machine: sorted witnesses / write-once memory");
let all_witnesses = witness_cols.iter().map(|c| c.name).collect::<HashSet<_>>();
let mut remaining_witnesses = all_witnesses.clone();
let mut base_identities = identities.iter().collect::<Vec<_>>();
for id in identities {
// Extract all witness columns in the RHS of the lookup.
let lookup_witnesses = &refs_in_selected_expressions(&id.right) & (&remaining_witnesses);
if lookup_witnesses.is_empty() {
continue;
}
machines.push(machine);
} else if let Some(machine) =
DoubleSortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses)
{
// Recursively extend the set to all witnesses connected through identities that preserve
// a fixed row relation.
let machine_witnesses =
all_row_connected_witnesses(lookup_witnesses, &remaining_witnesses, identities);
// Split identities into those that only concern the machine
// witnesses and those that concern any other witness.
let (machine_identities, remaining_identities): (Vec<_>, _) =
base_identities.iter().partition(|i| {
// The identity has at least one machine witness, but
// all referenced witnesses are machine witnesses.
let all_refs = &refs_in_identity(i) & (&all_witnesses);
!all_refs.is_empty() && all_refs.is_subset(&machine_witnesses)
});
base_identities = remaining_identities;
remaining_witnesses = &remaining_witnesses - &machine_witnesses;
if fixed.verbose {
println!("Detected machine: memory");
println!(
"Extracted a machine with the following witnesses and identities:\n{}\n{}",
machine_witnesses
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join(", "),
machine_identities
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join("\n")
);
}
if let Some(machine) =
SortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses)
{
if fixed.verbose {
println!("Detected machine: sorted witnesses / write-once memory");
}
machines.push(machine);
} else if let Some(machine) =
DoubleSortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses)
{
if fixed.verbose {
println!("Detected machine: memory");
}
machines.push(machine);
} else {
println!(
"Could not find a matching machine to handle a query to the following witness set:\n{}",
machine_witnesses
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join(", ")
);
remaining_witnesses = &remaining_witnesses | &machine_witnesses;
base_identities.extend(machine_identities);
println!("Will try to continue as is, but this probably requires a specialized machine implementation.");
}
machines.push(machine);
}
(machines, base_identities)
}
fn all_connected_witnesses<'a>(
all_witnesses: &'a ReferenceExtractor,
/// Extends a set of witnesses to the full set of row-connected witnesses.
/// Two witnesses are row-connected if they are part of a polynomial identity
/// or part of the same side of a lookup.
fn all_row_connected_witnesses<'a>(
mut witnesses: HashSet<&'a str>,
all_witnesses: &HashSet<&'a str>,
identities: &'a [Identity],
) -> HashSet<&'a str> {
let mut count = witnesses.len();
loop {
let count = witnesses.len();
for i in identities {
match i.kind {
crate::analyzer::IdentityKind::Polynomial => {
// Any current witness in the identity adds all other witnesses.
let in_identity = all_witnesses.in_identity(i);
let in_identity = &refs_in_identity(i) & all_witnesses;
if in_identity.intersection(&witnesses).next().is_some() {
witnesses.extend(in_identity);
}
@@ -85,12 +116,12 @@ fn all_connected_witnesses<'a>(
crate::analyzer::IdentityKind::Plookup
| crate::analyzer::IdentityKind::Permutation
| crate::analyzer::IdentityKind::Connect => {
// If we already have witnesses on the LHS, include the RHS, but not vice-versa.
let in_lhs = all_witnesses.in_selected_expressions(&i.left);
let in_rhs = all_witnesses.in_selected_expressions(&i.right);
// If we already have witnesses on the LHS, include the LHS,
// and vice-versa, but not across the "sides".
let in_lhs = &refs_in_selected_expressions(&i.left) & all_witnesses;
let in_rhs = &refs_in_selected_expressions(&i.right) & all_witnesses;
if in_lhs.intersection(&witnesses).next().is_some() {
witnesses.extend(in_lhs);
witnesses.extend(in_rhs);
} else if in_rhs.intersection(&witnesses).next().is_some() {
witnesses.extend(in_rhs);
}
@@ -100,66 +131,54 @@ fn all_connected_witnesses<'a>(
if witnesses.len() == count {
return witnesses;
}
count = witnesses.len()
}
}
/// Extracts all references to any of the given names
/// in expressions and identities.
struct ReferenceExtractor<'a> {
names: HashSet<&'a str>,
/// Extracts all references to names from an identity.
pub fn refs_in_identity(identity: &Identity) -> HashSet<&str> {
&refs_in_selected_expressions(&identity.left) | &refs_in_selected_expressions(&identity.right)
}
impl<'a> ReferenceExtractor<'a> {
pub fn new(names: HashSet<&'a str>) -> Self {
ReferenceExtractor { names }
}
pub fn in_identity(&self, identity: &'a Identity) -> HashSet<&'a str> {
&self.in_selected_expressions(&identity.left)
| &self.in_selected_expressions(&identity.right)
}
pub fn in_selected_expressions(&self, selexpr: &'a SelectedExpressions) -> HashSet<&'a str> {
selexpr
.expressions
.iter()
.chain(selexpr.selector.iter())
.map(|e| self.in_expression(e))
.reduce(|l, r| &l | &r)
.unwrap_or_default()
}
pub fn in_expression(&self, expr: &'a Expression) -> HashSet<&'a str> {
match expr {
Expression::Constant(_) => todo!(),
Expression::PolynomialReference(p) => {
if self.names.contains(p.name.as_str()) {
[p.name.as_str()].into()
} else {
HashSet::default()
}
}
Expression::Tuple(items) => self.in_expressions(items),
Expression::BinaryOperation(l, _, r) => &self.in_expression(l) | &self.in_expression(r),
Expression::UnaryOperation(_, e) => self.in_expression(e),
Expression::FunctionCall(_, args) => self.in_expressions(args),
Expression::MatchExpression(scrutinee, arms) => {
&self.in_expression(scrutinee)
| &arms
.iter()
.map(|(_, e)| self.in_expression(e))
.reduce(|a, b| &a | &b)
.unwrap_or_default()
}
Expression::LocalVariableReference(_)
| Expression::PublicReference(_)
| Expression::Number(_)
| Expression::String(_) => HashSet::default(),
/// Extracts all references to names from selected expressions.
pub fn refs_in_selected_expressions(selexpr: &SelectedExpressions) -> HashSet<&str> {
selexpr
.expressions
.iter()
.chain(selexpr.selector.iter())
.map(refs_in_expression)
.reduce(|l, r| &l | &r)
.unwrap_or_default()
}
/// Extracts all references to names from an expression
pub fn refs_in_expression(expr: &Expression) -> HashSet<&str> {
match expr {
Expression::Constant(_) => todo!(),
Expression::PolynomialReference(p) => [p.name.as_str()].into(),
Expression::Tuple(items) => refs_in_expressions(items),
Expression::BinaryOperation(l, _, r) => &refs_in_expression(l) | &refs_in_expression(r),
Expression::UnaryOperation(_, e) => refs_in_expression(e),
Expression::FunctionCall(_, args) => refs_in_expressions(args),
Expression::MatchExpression(scrutinee, arms) => {
&refs_in_expression(scrutinee)
| &arms
.iter()
.map(|(_, e)| refs_in_expression(e))
.reduce(|a, b| &a | &b)
.unwrap_or_default()
}
}
pub fn in_expressions(&self, exprs: &'a [Expression]) -> HashSet<&'a str> {
exprs
.iter()
.map(|e| self.in_expression(e))
.reduce(|l, r| &l | &r)
.unwrap_or_default()
Expression::LocalVariableReference(_)
| Expression::PublicReference(_)
| Expression::Number(_)
| Expression::String(_) => HashSet::default(),
}
}
/// Extracts all references to names from expressions.
pub fn refs_in_expressions(exprs: &[Expression]) -> HashSet<&str> {
exprs
.iter()
.map(refs_in_expression)
.reduce(|l, r| &l | &r)
.unwrap_or_default()
}