diff --git a/executor/src/witgen/machines/double_sorted_witness_machine.rs b/executor/src/witgen/machines/double_sorted_witness_machine.rs index 02ba99215..8d9e8a548 100644 --- a/executor/src/witgen/machines/double_sorted_witness_machine.rs +++ b/executor/src/witgen/machines/double_sorted_witness_machine.rs @@ -5,10 +5,11 @@ use itertools::{Either, Itertools}; use super::{FixedLookup, Machine}; use crate::witgen::affine_expression::AffineResult; +use crate::witgen::util::is_simple_poly_of_name; use crate::witgen::{EvalError, EvalResult, FixedData}; use crate::witgen::{EvalValue, IncompleteCause}; use number::FieldElement; -use pil_analyzer::PolynomialReference; + use pil_analyzer::{Expression, Identity, IdentityKind, SelectedExpressions}; /// TODO make this generic @@ -69,18 +70,8 @@ impl Machine for DoubleSortedWitnesses { right: &SelectedExpressions, ) -> Option { if kind != IdentityKind::Permutation - || (right.selector - != Some(Expression::PolynomialReference(PolynomialReference { - name: "Assembly.m_is_read".to_owned(), - index: None, - next: false, - })) - && right.selector - != Some(Expression::PolynomialReference(PolynomialReference { - name: "Assembly.m_is_write".to_owned(), - index: None, - next: false, - }))) + || !(is_simple_poly_of_name(right.selector.as_ref()?, "Assembly.m_is_read") + || is_simple_poly_of_name(right.selector.as_ref()?, "Assembly.m_is_write")) { return None; } diff --git a/executor/src/witgen/util.rs b/executor/src/witgen/util.rs index 8db5e1f02..6654590c2 100644 --- a/executor/src/witgen/util.rs +++ b/executor/src/witgen/util.rs @@ -41,10 +41,12 @@ pub fn contains_witness_ref(expr: &Expression, fixed_data: &FixedData) -> bool { /// - not shifted with `'` /// and return the polynomial's name if so pub fn is_simple_poly(expr: &Expression) -> Option<&str> { + // TODO return the ID and not the str if let Expression::PolynomialReference(PolynomialReference { name, index: None, next: false, + .. }) = expr { Some(name) @@ -52,3 +54,17 @@ pub fn is_simple_poly(expr: &Expression) -> Option<&str> { None } } + +pub fn is_simple_poly_of_name(expr: &Expression, poly_name: &str) -> bool { + if let Expression::PolynomialReference(PolynomialReference { + name, + index: None, + next: false, + .. + }) = expr + { + name == poly_name + } else { + false + } +} diff --git a/pil_analyzer/src/display.rs b/pil_analyzer/src/display.rs index 5291375bf..32f5e1028 100644 --- a/pil_analyzer/src/display.rs +++ b/pil_analyzer/src/display.rs @@ -211,12 +211,14 @@ namespace T(65536); col witness instr_dec_CNT; col witness instr_assert_zero; (T.instr_assert_zero * (T.XIsZero - 1)) = 0; + col witness X; col witness X_const; col witness X_read_free; col witness A; col witness CNT; col witness read_X_A; col witness read_X_CNT; + col witness reg_write_X_CNT; col witness read_X_pc; col witness reg_write_X_A; T.X = ((((T.read_X_A * T.A) + (T.read_X_CNT * T.CNT)) + T.X_const) + (T.X_read_free * T.X_free_value)); diff --git a/pil_analyzer/src/json_exporter/mod.rs b/pil_analyzer/src/json_exporter/mod.rs index 4cb78f316..7e98c1de1 100644 --- a/pil_analyzer/src/json_exporter/mod.rs +++ b/pil_analyzer/src/json_exporter/mod.rs @@ -314,22 +314,27 @@ impl<'a> Exporter<'a> { fn polynomial_reference_to_json( &self, - PolynomialReference { name, index, next }: &PolynomialReference, + PolynomialReference { + name: _, + index, + poly_id, + next, + }: &PolynomialReference, ) -> (u32, JsonValue, Vec) { - let poly = &self.analyzed.definitions[name].0; - let id = if poly.poly_type == PolynomialType::Intermediate { + let (id, poly_type) = poly_id.unwrap(); + let id = if poly_type == PolynomialType::Intermediate { assert!(index.is_none()); - self.intermediate_poly_expression_ids[&poly.id] + self.intermediate_poly_expression_ids[&id] } else { - poly.id + index.unwrap_or_default() + id + index.unwrap_or_default() }; let poly_json = object! { id: id, - op: polynomial_reference_type_to_json_string(poly.poly_type), + op: polynomial_reference_type_to_json_string(poly_type), deg: 1, next: *next, }; - let dependencies = if poly.poly_type == PolynomialType::Intermediate { + let dependencies = if poly_type == PolynomialType::Intermediate { vec![id] } else { Vec::new() diff --git a/pil_analyzer/src/lib.rs b/pil_analyzer/src/lib.rs index 8415c3764..a75921c23 100644 --- a/pil_analyzer/src/lib.rs +++ b/pil_analyzer/src/lib.rs @@ -92,6 +92,7 @@ impl Analyzed { } } +#[derive(Debug, Clone)] pub struct Polynomial { pub id: u64, pub source: SourceRef, @@ -183,16 +184,40 @@ pub enum Expression { MatchExpression(Box, Vec<(Option, Expression)>), } -#[derive(Debug, PartialEq, Eq, Default, Clone)] +#[derive(Debug, Clone, Eq)] pub struct PolynomialReference { - // TODO would be better to use numeric IDs instead of names, - // but the IDs as they are overlap. Maybe we can change that. + /// Name of the polynomial - just for informational purposes. + /// Comparisons are based on polynomial ID. pub name: String, + /// Identifier for a polynomial reference. + /// Optional because it is filled in in a second stage of analysis. + /// TODO make this non-optional + pub poly_id: Option<(u64, PolynomialType)>, pub index: Option, pub next: bool, } -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +impl PartialOrd for PolynomialReference { + fn partial_cmp(&self, other: &Self) -> Option { + // TODO for efficiency reasons, we should avoid the unwrap check here somehow. + match self.poly_id.unwrap().partial_cmp(&other.poly_id.unwrap()) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + assert!(self.index.is_none() && other.index.is_none()); + self.next.partial_cmp(&other.next) + } +} + +impl PartialEq for PolynomialReference { + fn eq(&self, other: &Self) -> bool { + assert!(self.index.is_none() && other.index.is_none()); + // TODO for efficiency reasons, we should avoid the unwrap check here somehow. + self.poly_id.unwrap() == other.poly_id.unwrap() && self.next == other.next + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum PolynomialType { Committed, Constant, diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index cc9385c79..9316f0d7b 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -6,6 +6,8 @@ use number::DegreeType; use parser::ast; pub use parser::ast::{BinaryOperator, UnaryOperator}; +use crate::util::previsit_expressions_in_pil_file_mut; + use super::*; pub fn process_pil_file(path: &Path) -> Analyzed { @@ -65,13 +67,34 @@ impl From for Analyzed { .. }: PILContext, ) -> Self { - Self { + let ids = definitions + .iter() + .map(|(name, (poly, _))| (name.clone(), poly.clone())) + .collect::>(); + let mut result = Self { constants, definitions, public_declarations, identities, source_order, - } + }; + let assign_id = |reference: &mut PolynomialReference| { + let poly = ids + .get(&reference.name) + .unwrap_or_else(|| panic!("Column {} not found.", reference.name)); + reference.poly_id = Some((poly.id, poly.poly_type)); + }; + previsit_expressions_in_pil_file_mut(&mut result, &mut |e| { + if let Expression::PolynomialReference(reference) = e { + assign_id(reference); + } + std::ops::ControlFlow::Continue::<()>(()) + }); + result + .public_declarations + .values_mut() + .for_each(|public_decl| assign_id(&mut public_decl.polynomial)); + result } } @@ -578,8 +601,10 @@ impl PILContext { .as_ref() .map(|i| self.evaluate_expression(i).unwrap()) .map(|i| i.to_degree()); + let name = self.namespaced_ref(&poly.namespace, &poly.name); PolynomialReference { - name: self.namespaced_ref(&poly.namespace, &poly.name), + name, + poly_id: None, index, next: poly.next, } diff --git a/pil_analyzer/src/util.rs b/pil_analyzer/src/util.rs index 8b3649af4..9aefe9d55 100644 --- a/pil_analyzer/src/util.rs +++ b/pil_analyzer/src/util.rs @@ -1,6 +1,89 @@ +use crate::{Analyzed, Expression, FunctionValueDefinition, Identity}; use std::{iter::once, ops::ControlFlow}; -use crate::Expression; +/// Calls `f` on each expression in the pil file and then descends into the +/// (potentially modified) expression. +pub fn previsit_expressions_in_pil_file_mut( + pil_file: &mut Analyzed, + f: &mut F, +) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + pil_file + .definitions + .values_mut() + .try_for_each(|(_poly, definition)| match definition { + Some(FunctionValueDefinition::Mapping(e)) | Some(FunctionValueDefinition::Query(e)) => { + previsit_expression_mut(e, f) + } + Some(FunctionValueDefinition::Array(elements)) => elements + .iter_mut() + .flat_map(|e| e.values.iter_mut()) + .try_for_each(|e| previsit_expression_mut(e, f)), + None => ControlFlow::Continue(()), + })?; + + pil_file + .identities + .iter_mut() + .try_for_each(|i| previsit_expressions_in_identity_mut(i, f)) +} + +pub fn postvisit_expressions_in_identity_mut(i: &mut Identity, f: &mut F) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + i.left + .selector + .as_mut() + .into_iter() + .chain(i.right.selector.as_mut()) + .try_for_each(move |item| postvisit_expression_mut(item, f)) +} + +/// Calls `f` on each expression in the pil file and then descends into the +/// (potentially modified) expression. +pub fn postvisit_expressions_in_pil_file_mut( + pil_file: &mut Analyzed, + f: &mut F, +) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + pil_file + .definitions + .values_mut() + .try_for_each(|(_poly, definition)| match definition { + Some(FunctionValueDefinition::Mapping(e)) | Some(FunctionValueDefinition::Query(e)) => { + postvisit_expression_mut(e, f) + } + Some(FunctionValueDefinition::Array(elements)) => elements + .iter_mut() + .flat_map(|e| e.values.iter_mut()) + .try_for_each(|e| postvisit_expression_mut(e, f)), + None => ControlFlow::Continue(()), + })?; + + pil_file + .identities + .iter_mut() + .try_for_each(|i| postvisit_expressions_in_identity_mut(i, f)) +} + +pub fn previsit_expressions_in_identity_mut(i: &mut Identity, f: &mut F) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + i.left + .selector + .as_mut() + .into_iter() + .chain(i.left.expressions.iter_mut()) + .chain(i.right.selector.as_mut()) + .chain(i.right.expressions.iter_mut()) + .try_for_each(move |item| previsit_expression_mut(item, f)) +} /// Visits `expr` and all of its sub-expressions and returns true if `f` returns true on any of them. pub fn expr_any(expr: &Expression, mut f: impl FnMut(&Expression) -> bool) -> bool { @@ -75,3 +158,32 @@ where }; ControlFlow::Continue(()) } + +/// Traverses the expression tree and calls `f` in post-order. +pub fn postvisit_expression_mut(e: &mut Expression, f: &mut F) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + match e { + Expression::PolynomialReference(_) + | Expression::Constant(_) + | Expression::LocalVariableReference(_) + | Expression::PublicReference(_) + | Expression::Number(_) + | Expression::String(_) => {} + Expression::BinaryOperation(left, _, right) => { + postvisit_expression_mut(left, f)?; + postvisit_expression_mut(right, f)?; + } + Expression::UnaryOperation(_, e) => postvisit_expression_mut(e.as_mut(), f)?, + Expression::Tuple(items) | Expression::FunctionCall(_, items) => items + .iter_mut() + .try_for_each(|item| postvisit_expression_mut(item, f))?, + Expression::MatchExpression(scrutinee, arms) => { + once(scrutinee.as_mut()) + .chain(arms.iter_mut().map(|(_n, e)| e)) + .try_for_each(|item| postvisit_expression_mut(item, f))?; + } + }; + f(e) +}