From 8782d534419cae3c19a79a5798b06ae27bd514ee Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 26 Sep 2023 12:24:00 +0200 Subject: [PATCH] Visitors as bound functions. --- analysis/src/macro_expansion.rs | 11 +- asm_to_pil/src/romgen.rs | 43 +- asm_to_pil/src/vm_to_constrained.rs | 5 +- ast/src/analyzed/mod.rs | 7 +- ast/src/analyzed/util.rs | 142 ------ ast/src/analyzed/visitor.rs | 81 ++++ ast/src/asm_analysis/mod.rs | 50 +- ast/src/asm_analysis/utils.rs | 30 -- ast/src/parsed/mod.rs | 18 + ast/src/parsed/utils.rs | 196 +------- ast/src/parsed/visitor.rs | 439 ++++++++++++++++++ .../src/witgen/machines/machine_extractor.rs | 13 +- executor/src/witgen/util.rs | 5 +- pil_analyzer/src/pil_analyzer.rs | 15 +- pilopt/src/lib.rs | 22 +- 15 files changed, 632 insertions(+), 445 deletions(-) delete mode 100644 ast/src/analyzed/util.rs create mode 100644 ast/src/analyzed/visitor.rs create mode 100644 ast/src/parsed/visitor.rs diff --git a/analysis/src/macro_expansion.rs b/analysis/src/macro_expansion.rs index b8c383f15..93ddc64c6 100644 --- a/analysis/src/macro_expansion.rs +++ b/analysis/src/macro_expansion.rs @@ -1,13 +1,12 @@ use std::{ collections::{HashMap, HashSet}, convert::Infallible, - ops::ControlFlow, }; use ast::parsed::{ asm::{ASMProgram, Instruction, InstructionBody, Machine, MachineStatement}, folder::Folder, - utils::{postvisit_expression_in_statement_mut, postvisit_expression_mut}, + visitor::ExpressionVisitable, Expression, FunctionDefinition, PilStatement, }; use number::FieldElement; @@ -90,7 +89,7 @@ where } } - postvisit_expression_in_statement_mut(&mut statement, &mut |e| self.process_expression(e)); + statement.post_visit_expressions_mut(&mut |e| self.process_expression(e)); match &mut statement { PilStatement::FunctionCall(_start, name, arguments) => { @@ -147,7 +146,7 @@ where self.handle_statement(identity) } if let Some(e) = &mut expression { - postvisit_expression_mut(e, &mut |e| self.process_expression(e)); + e.post_visit_expressions_mut(&mut |e| self.process_expression(e)); }; self.arguments = old_arguments; @@ -155,7 +154,7 @@ where expression } - fn process_expression(&mut self, e: &mut Expression) -> ControlFlow<()> { + fn process_expression(&mut self, e: &mut Expression) { if let Expression::Reference(poly) = e { if poly.namespace().is_none() && self.parameter_names.contains_key(poly.name()) { // TODO to make this work inside macros, "next" and "index" need to be @@ -171,7 +170,5 @@ where .expect("Invoked a macro in expression context with empty expression.") } } - - ControlFlow::<()>::Continue(()) } } diff --git a/asm_to_pil/src/romgen.rs b/asm_to_pil/src/romgen.rs index d61c4d9ad..2ad1f2af5 100644 --- a/asm_to_pil/src/romgen.rs +++ b/asm_to_pil/src/romgen.rs @@ -1,17 +1,15 @@ //! Generate one ROM per machine from all declared functions -use std::{collections::HashMap, iter::repeat, ops::ControlFlow}; +use std::{collections::HashMap, iter::repeat}; -use ast::{ - asm_analysis::{ - utils::previsit_expression_in_statement_mut, Batch, CallableSymbol, FunctionStatement, - FunctionSymbol, Incompatible, IncompatibleSet, Machine, OperationSymbol, PilBlock, Rom, - }, - parsed::{ - asm::{OperationId, Param, ParamList, Params}, - utils::previsit_expression_mut, - Expression, - }, +use ast::asm_analysis::{ + Batch, CallableSymbol, FunctionStatement, FunctionSymbol, Incompatible, IncompatibleSet, + Machine, OperationSymbol, PilBlock, Rom, +}; +use ast::parsed::visitor::ExpressionVisitable; +use ast::parsed::{ + asm::{OperationId, Param, ParamList, Params}, + Expression, }; use number::FieldElement; @@ -29,23 +27,16 @@ use crate::{ fn substitute_name_in_statement_expressions( s: &mut FunctionStatement, substitution: &HashMap, -) -> ControlFlow<()> { - fn substitute( - e: &mut Expression, - substitution: &HashMap, - ) -> ControlFlow<()> { - previsit_expression_mut(e, &mut |e| { - if let Expression::Reference(r) = e { - if let Some(v) = substitution.get(r.name()).cloned() { - *r.name_mut() = v; - } - }; - ControlFlow::Continue::<()>(()) - }); - ControlFlow::Continue(()) +) { + fn substitute(e: &mut Expression, substitution: &HashMap) { + if let Expression::Reference(r) = e { + if let Some(v) = substitution.get(r.name()).cloned() { + *r.name_mut() = v; + } + }; } - previsit_expression_in_statement_mut(s, &mut |e| substitute(e, substitution)) + s.pre_visit_expressions_mut(&mut |e| substitute(e, substitution)) } /// Pad the arguments in the `return` statements with zeroes to match the maximum number of outputs diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index 1c5e5b9e3..01f4c21cf 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -14,7 +14,7 @@ use ast::{ build_add, build_binary_expr, build_mul, build_number, build_sub, direct_reference, next_reference, }, - utils::postvisit_expression_in_statement_mut, + visitor::ExpressionVisitable, ArrayExpression, BinaryOperator, Expression, FunctionDefinition, MatchArm, MatchPattern, PilStatement, PolynomialName, SelectedExpressions, UnaryOperator, }, @@ -382,13 +382,12 @@ impl ASMPILConverter { }) .collect::>(); body.iter_mut().for_each(|s| { - postvisit_expression_in_statement_mut(s, &mut |e| { + s.post_visit_expressions_mut(&mut |e| { if let Expression::Reference(r) = e { if let Some(sub) = substitutions.get(r.name()) { *r.name_mut() = sub.to_string(); } } - std::ops::ControlFlow::Continue::<()>(()) }); }); diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index d1d979828..4779552ac 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -1,6 +1,6 @@ pub mod build; mod display; -pub mod util; +pub mod visitor; use core::hash::Hash; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; @@ -9,9 +9,9 @@ use std::ops::ControlFlow; use number::DegreeType; -use crate::analyzed::util::previsit_expressions_in_pil_file_mut; use crate::parsed; use crate::parsed::utils::expr_any; +use crate::parsed::visitor::ExpressionVisitable; pub use crate::parsed::BinaryOperator; pub use crate::parsed::UnaryOperator; @@ -161,13 +161,12 @@ impl Analyzed { assert!(!to_remove.contains(&poly_id)); poly.id = replacements[&poly_id].id; }); - previsit_expressions_in_pil_file_mut(self, &mut |expr| { + self.pre_visit_expressions_mut(&mut |expr| { if let Expression::Reference(Reference::Poly(poly)) = expr { let poly_id = poly.poly_id.unwrap(); assert!(!to_remove.contains(&poly_id)); poly.poly_id = Some(replacements[&poly_id]); } - ControlFlow::Continue::<()>(()) }); } diff --git a/ast/src/analyzed/util.rs b/ast/src/analyzed/util.rs deleted file mode 100644 index 66213ca8e..000000000 --- a/ast/src/analyzed/util.rs +++ /dev/null @@ -1,142 +0,0 @@ -use std::ops::ControlFlow; - -use crate::parsed::utils::{ - postvisit_expression_mut, previsit_expression, previsit_expression_mut, -}; - -use super::{Analyzed, Expression, FunctionValueDefinition, Identity, SelectedExpressions}; - -/// Calls `f` on each expression in the pil file and then descends into the -/// (potentially modified) expression. -pub fn previsit_expressions_in_pil_file_mut( - pil_file: &mut Analyzed, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - pil_file - .definitions - .values_mut() - .try_for_each(|(_poly, definition)| match definition { - Some(FunctionValueDefinition::Mapping(e)) | Some(FunctionValueDefinition::Query(e)) => { - previsit_expression_mut(e, f) - } - Some(FunctionValueDefinition::Array(elements)) => elements - .iter_mut() - .flat_map(|e| e.pattern.iter_mut()) - .try_for_each(|e| previsit_expression_mut(e, f)), - Some(FunctionValueDefinition::Expression(e)) => previsit_expression_mut(e, f), - None => ControlFlow::Continue(()), - })?; - - pil_file - .identities - .iter_mut() - .try_for_each(|i| previsit_expressions_in_identity_mut(i, f)) -} - -/// Calls `f` on each expression in the pil file and then descends into the -/// (potentially modified) expression. -pub fn postvisit_expressions_in_pil_file_mut( - pil_file: &mut Analyzed, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - pil_file - .definitions - .values_mut() - .try_for_each(|(_poly, definition)| match definition { - Some(FunctionValueDefinition::Mapping(e)) | Some(FunctionValueDefinition::Query(e)) => { - postvisit_expression_mut(e, f) - } - Some(FunctionValueDefinition::Array(elements)) => elements - .iter_mut() - .flat_map(|e| e.pattern.iter_mut()) - .try_for_each(|e| postvisit_expression_mut(e, f)), - Some(FunctionValueDefinition::Expression(e)) => postvisit_expression_mut(e, f), - None => ControlFlow::Continue(()), - })?; - - pil_file - .identities - .iter_mut() - .try_for_each(|i| postvisit_expressions_in_identity_mut(i, f)) -} - -pub fn previsit_expressions_in_identity(i: &Identity, f: &mut F) -> ControlFlow -where - F: FnMut(&Expression) -> ControlFlow, -{ - [&i.left, &i.right] - .iter() - .try_for_each(move |item| previsit_expressions_in_selected_expressions(item, f)) -} - -pub fn previsit_expressions_in_identity_mut( - i: &mut Identity, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - [&mut i.left, &mut i.right] - .iter_mut() - .try_for_each(move |item| previsit_expressions_in_selected_expressions_mut(item, f)) -} - -pub fn postvisit_expressions_in_identity_mut( - i: &mut Identity, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - [&mut i.left, &mut i.right] - .iter_mut() - .try_for_each(move |item| postvisit_expressions_in_selected_expressions_mut(item, f)) -} - -pub fn previsit_expressions_in_selected_expressions( - s: &SelectedExpressions, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&Expression) -> ControlFlow, -{ - s.selector - .as_ref() - .into_iter() - .chain(s.expressions.iter()) - .try_for_each(move |item| previsit_expression(item, f)) -} - -pub fn previsit_expressions_in_selected_expressions_mut( - s: &mut SelectedExpressions, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - s.selector - .as_mut() - .into_iter() - .chain(s.expressions.iter_mut()) - .try_for_each(move |item| previsit_expression_mut(item, f)) -} - -pub fn postvisit_expressions_in_selected_expressions_mut( - s: &mut SelectedExpressions, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - s.selector - .as_mut() - .into_iter() - .chain(s.expressions.iter_mut()) - .try_for_each(move |item| postvisit_expression_mut(item, f)) -} diff --git a/ast/src/analyzed/visitor.rs b/ast/src/analyzed/visitor.rs new file mode 100644 index 000000000..d397b814a --- /dev/null +++ b/ast/src/analyzed/visitor.rs @@ -0,0 +1,81 @@ +use crate::parsed::visitor::VisitOrder; + +use super::*; + +impl ExpressionVisitable for Analyzed { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut parsed::Expression) -> ControlFlow, + { + // TODO add constants if we change them to expressions at some point. + self.definitions + .values_mut() + .try_for_each(|(_poly, definition)| match definition { + Some(FunctionValueDefinition::Mapping(e)) + | Some(FunctionValueDefinition::Query(e)) => e.visit_expressions_mut(f, o), + Some(FunctionValueDefinition::Array(elements)) => elements + .iter_mut() + .flat_map(|e| e.pattern.iter_mut()) + .try_for_each(|e| e.visit_expressions_mut(f, o)), + Some(FunctionValueDefinition::Expression(e)) => e.visit_expressions_mut(f, o), + None => ControlFlow::Continue(()), + })?; + + self.identities + .iter_mut() + .try_for_each(|i| i.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&parsed::Expression) -> ControlFlow, + { + // TODO add constants if we change them to expressions at some point. + self.definitions + .values() + .try_for_each(|(_poly, definition)| match definition { + Some(FunctionValueDefinition::Mapping(e)) + | Some(FunctionValueDefinition::Query(e)) => e.visit_expressions(f, o), + Some(FunctionValueDefinition::Array(elements)) => elements + .iter() + .flat_map(|e| e.pattern.iter()) + .try_for_each(|e| e.visit_expressions(f, o)), + Some(FunctionValueDefinition::Expression(e)) => e.visit_expressions(f, o), + None => ControlFlow::Continue(()), + })?; + + self.identities + .iter() + .try_for_each(|i| i.visit_expressions(f, o)) + } +} + +impl ExpressionVisitable for Identity { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut parsed::Expression) -> ControlFlow, + { + self.left + .selector + .as_mut() + .into_iter() + .chain(self.left.expressions.iter_mut()) + .chain(self.right.selector.as_mut()) + .chain(self.right.expressions.iter_mut()) + .try_for_each(move |item| item.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&parsed::Expression) -> ControlFlow, + { + self.left + .selector + .as_ref() + .into_iter() + .chain(self.left.expressions.iter()) + .chain(self.right.selector.iter()) + .chain(self.right.expressions.iter()) + .try_for_each(move |item| item.visit_expressions(f, o)) + } +} diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index 0ee69d83c..0a79b2e69 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -7,6 +7,7 @@ use std::{ BTreeMap, BTreeSet, }, iter::{once, repeat}, + ops::ControlFlow, }; use itertools::Either; @@ -17,7 +18,8 @@ use crate::parsed::{ asm::{ AbsoluteSymbolPath, AssignmentRegister, CallableRef, InstructionBody, OperationId, Params, }, - PilStatement, + visitor::{ExpressionVisitable, VisitOrder}, + PilStatement, ShiftedPolynomialReference, }; pub use crate::parsed::Expression; @@ -539,6 +541,52 @@ pub enum FunctionStatement { Return(Return), } +impl ExpressionVisitable> for FunctionStatement { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> std::ops::ControlFlow + where + F: FnMut(&mut Expression>) -> std::ops::ControlFlow, + { + match self { + FunctionStatement::Assignment(assignment) => { + assignment.rhs.as_mut().visit_expressions_mut(f, o) + } + FunctionStatement::Instruction(instruction) => instruction + .inputs + .iter_mut() + .try_for_each(move |i| i.visit_expressions_mut(f, o)), + FunctionStatement::Label(_) | FunctionStatement::DebugDirective(..) => { + ControlFlow::Continue(()) + } + FunctionStatement::Return(ret) => ret + .values + .iter_mut() + .try_for_each(move |e| e.visit_expressions_mut(f, o)), + } + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> std::ops::ControlFlow + where + F: FnMut(&Expression>) -> std::ops::ControlFlow, + { + match self { + FunctionStatement::Assignment(assignment) => { + assignment.rhs.as_ref().visit_expressions(f, o) + } + FunctionStatement::Instruction(instruction) => instruction + .inputs + .iter() + .try_for_each(move |i| i.visit_expressions(f, o)), + FunctionStatement::Label(_) | FunctionStatement::DebugDirective(..) => { + ControlFlow::Continue(()) + } + FunctionStatement::Return(ret) => ret + .values + .iter() + .try_for_each(move |e| e.visit_expressions(f, o)), + } + } +} + impl From> for FunctionStatement { fn from(value: AssignmentStatement) -> Self { Self::Assignment(value) diff --git a/ast/src/asm_analysis/utils.rs b/ast/src/asm_analysis/utils.rs index 8cc6ebf21..b28b04f64 100644 --- a/ast/src/asm_analysis/utils.rs +++ b/ast/src/asm_analysis/utils.rs @@ -1,33 +1,3 @@ -use std::ops::ControlFlow; -use crate::parsed::{utils::previsit_expression_mut, Expression}; -use super::FunctionStatement; -/// Traverses the expression tree and calls `f` in pre-order. -pub fn previsit_expression_in_statement_mut( - s: &mut FunctionStatement, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - match s { - FunctionStatement::Assignment(assignment) => { - previsit_expression_mut(assignment.rhs.as_mut(), f)?; - } - FunctionStatement::Instruction(instruction) => { - for i in &mut instruction.inputs { - previsit_expression_mut(i, f)?; - } - } - FunctionStatement::Label(_) | FunctionStatement::DebugDirective(..) => {} - FunctionStatement::Return(ret) => { - for e in &mut ret.values { - previsit_expression_mut(e, f); - } - } - } - - ControlFlow::Continue(()) -} diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 27ea94a25..3098f9995 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -3,6 +3,7 @@ pub mod build; pub mod display; pub mod folder; pub mod utils; +pub mod visitor; use number::{DegreeType, FieldElement}; @@ -78,6 +79,23 @@ pub enum Expression> { MatchExpression(Box>, Vec>), } +impl Expression { + /// Visits this expression and all of its sub-expressions and returns true + /// if `f` returns true on any of them. + pub fn any(&self, mut f: impl FnMut(&Self) -> bool) -> bool { + use std::ops::ControlFlow; + use visitor::ExpressionVisitable; + self.pre_visit_expressions_return(&mut |e| { + if f(e) { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + }) + .is_break() + } +} + impl From> for Expression { fn from(value: ShiftedPolynomialReference) -> Self { Self::Reference(value) diff --git a/ast/src/parsed/utils.rs b/ast/src/parsed/utils.rs index 36225ca25..59a0c5cf9 100644 --- a/ast/src/parsed/utils.rs +++ b/ast/src/parsed/utils.rs @@ -1,16 +1,13 @@ -use std::{iter::once, ops::ControlFlow}; +use std::ops::ControlFlow; -use super::{ - ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, LambdaExpression, - MatchArm, PilStatement, -}; +use super::{visitor::ExpressionVisitable, Expression}; /// Visits `expr` and all of its sub-expressions and returns true if `f` returns true on any of them. pub fn expr_any( expr: &Expression, mut f: impl FnMut(&Expression) -> bool, ) -> bool { - previsit_expression(expr, &mut |e| { + expr.pre_visit_expressions_return(&mut |e| { if f(e) { ControlFlow::Break(()) } else { @@ -19,190 +16,3 @@ pub fn expr_any( }) .is_break() } - -/// Traverses the expression trees of the statement and calls `f` in post-order. -/// Does not enter macro definitions. -pub fn postvisit_expression_in_statement_mut( - statement: &mut PilStatement, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - match statement { - PilStatement::FunctionCall(_, _, arguments) => arguments - .iter_mut() - .try_for_each(|e| postvisit_expression_mut(e, f)), - PilStatement::PlookupIdentity(_, left, right) - | PilStatement::PermutationIdentity(_, left, right) => left - .selector - .iter_mut() - .chain(left.expressions.iter_mut()) - .chain(right.selector.iter_mut()) - .chain(right.expressions.iter_mut()) - .try_for_each(|e| postvisit_expression_mut(e, f)), - PilStatement::ConnectIdentity(_start, left, right) => left - .iter_mut() - .chain(right.iter_mut()) - .try_for_each(|e| postvisit_expression_mut(e, f)), - - PilStatement::Namespace(_, _, e) - | PilStatement::PolynomialDefinition(_, _, e) - | PilStatement::PolynomialIdentity(_, e) - | PilStatement::PublicDeclaration(_, _, _, e) - | PilStatement::ConstantDefinition(_, _, e) - | PilStatement::LetStatement(_, _, Some(e)) => postvisit_expression_mut(e, f), - - PilStatement::PolynomialConstantDefinition(_, _, fundef) - | PilStatement::PolynomialCommitDeclaration(_, _, Some(fundef)) => match fundef { - FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => { - postvisit_expression_mut(e, f) - } - FunctionDefinition::Array(ae) => postvisit_expression_in_array_expression_mut(ae, f), - FunctionDefinition::Expression(e) => postvisit_expression_mut(e, f), - }, - PilStatement::PolynomialCommitDeclaration(_, _, None) - | PilStatement::Include(_, _) - | PilStatement::PolynomialConstantDeclaration(_, _) - | PilStatement::MacroDefinition(_, _, _, _, _) - | PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()), - } -} - -fn postvisit_expression_in_array_expression_mut( - ae: &mut ArrayExpression, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - match ae { - ArrayExpression::Value(expressions) | ArrayExpression::RepeatedValue(expressions) => { - expressions - .iter_mut() - .try_for_each(|e| postvisit_expression_mut(e, f)) - } - ArrayExpression::Concat(a1, a2) => [a1, a2] - .iter_mut() - .try_for_each(|e| postvisit_expression_in_array_expression_mut(e, f)), - } -} - -/// Traverses the expression tree and calls `f` in pre-order. -pub fn previsit_expression<'a, T, Ref, F, B>(e: &'a Expression, f: &mut F) -> ControlFlow -where - F: FnMut(&'a Expression) -> ControlFlow, -{ - f(e)?; - - match e { - Expression::Reference(_) - | Expression::Constant(_) - | Expression::PublicReference(_) - | Expression::Number(_) - | Expression::String(_) => {} - Expression::BinaryOperation(left, _, right) => { - previsit_expression(left, f)?; - previsit_expression(right, f)?; - } - Expression::FreeInput(e) - | Expression::UnaryOperation(_, e) - | Expression::LambdaExpression(LambdaExpression { params: _, body: e }) => { - previsit_expression(e, f)? - } - Expression::Tuple(items) - | Expression::ArrayLiteral(ArrayLiteral { items }) - | Expression::FunctionCall(FunctionCall { - id: _, - arguments: items, - }) => items - .iter() - .try_for_each(|item| previsit_expression(item, f))?, - Expression::MatchExpression(scrutinee, arms) => { - once(scrutinee.as_ref()) - .chain(arms.iter().map(|MatchArm { pattern: _, value }| value)) - .try_for_each(move |item| previsit_expression(item, f))?; - } - }; - ControlFlow::Continue(()) -} - -/// Traverses the expression tree and calls `f` in pre-order. -pub fn previsit_expression_mut( - e: &mut Expression, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - f(e)?; - - match e { - Expression::Reference(_) - | Expression::Constant(_) - | Expression::PublicReference(_) - | Expression::Number(_) - | Expression::String(_) => {} - Expression::BinaryOperation(left, _, right) => { - previsit_expression_mut(left, f)?; - previsit_expression_mut(right, f)?; - } - Expression::FreeInput(e) - | Expression::UnaryOperation(_, e) - | Expression::LambdaExpression(LambdaExpression { params: _, body: e }) => { - previsit_expression_mut(e.as_mut(), f)? - } - Expression::Tuple(items) - | Expression::ArrayLiteral(ArrayLiteral { items }) - | Expression::FunctionCall(FunctionCall { - arguments: items, .. - }) => items - .iter_mut() - .try_for_each(|item| previsit_expression_mut(item, f))?, - Expression::MatchExpression(scrutinee, arms) => { - once(scrutinee.as_mut()) - .chain(arms.iter_mut().map(|MatchArm { pattern: _, value }| value)) - .try_for_each(move |item| previsit_expression_mut(item, f))?; - } - }; - ControlFlow::Continue(()) -} - -/// Traverses the expression tree and calls `f` in post-order. -pub fn postvisit_expression_mut( - e: &mut Expression, - f: &mut F, -) -> ControlFlow -where - F: FnMut(&mut Expression) -> ControlFlow, -{ - match e { - Expression::Reference(_) - | Expression::Constant(_) - | Expression::PublicReference(_) - | Expression::Number(_) - | Expression::String(_) => {} - Expression::BinaryOperation(left, _, right) => { - postvisit_expression_mut(left, f)?; - postvisit_expression_mut(right, f)?; - } - Expression::FreeInput(e) - | Expression::UnaryOperation(_, e) - | Expression::LambdaExpression(LambdaExpression { params: _, body: e }) => { - postvisit_expression_mut(e.as_mut(), f)? - } - Expression::Tuple(items) - | Expression::ArrayLiteral(ArrayLiteral { items }) - | Expression::FunctionCall(FunctionCall { - arguments: items, .. - }) => items - .iter_mut() - .try_for_each(|item| postvisit_expression_mut(item, f))?, - Expression::MatchExpression(scrutinee, arms) => { - once(scrutinee.as_mut()) - .chain(arms.iter_mut().map(|MatchArm { pattern: _, value }| value)) - .try_for_each(|item| postvisit_expression_mut(item, f))?; - } - }; - f(e) -} diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs new file mode 100644 index 000000000..06616f6ef --- /dev/null +++ b/ast/src/parsed/visitor.rs @@ -0,0 +1,439 @@ +use std::ops::ControlFlow; + +use super::{ + ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, LambdaExpression, + MatchArm, MatchPattern, PilStatement, SelectedExpressions, ShiftedPolynomialReference, +}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum VisitOrder { + Pre, + Post, +} + +/// A trait to be implemented by an AST node. +/// The idea is that it calls a callback function on each of the sub-nodes +/// that are expressions. +pub trait ExpressionVisitable { + /// Traverses the AST and calls `f` on each Expression in pre-order, + /// potentially break early and return a value. + fn pre_visit_expressions_return_mut(&mut self, f: &mut F) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + self.visit_expressions_mut(f, VisitOrder::Pre) + } + + /// Traverses the AST and calls `f` on each Expression in pre-order. + fn pre_visit_expressions_mut(&mut self, f: &mut F) + where + F: FnMut(&mut Expression), + { + self.pre_visit_expressions_return_mut(&mut move |e| { + f(e); + ControlFlow::Continue::<()>(()) + }); + } + + /// Traverses the AST and calls `f` on each Expression in pre-order, + /// potentially break early and return a value. + fn pre_visit_expressions_return(&self, f: &mut F) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + self.visit_expressions(f, VisitOrder::Pre) + } + + /// Traverses the AST and calls `f` on each Expression in pre-order. + fn pre_visit_expressions(&self, f: &mut F) + where + F: FnMut(&Expression), + { + self.pre_visit_expressions_return(&mut move |e| { + f(e); + ControlFlow::Continue::<()>(()) + }); + } + + /// Traverses the AST and calls `f` on each Expression in post-order, + /// potentially break early and return a value. + fn post_visit_expressions_return_mut(&mut self, f: &mut F) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + self.visit_expressions_mut(f, VisitOrder::Post) + } + + /// Traverses the AST and calls `f` on each Expression in post-order. + fn post_visit_expressions_mut(&mut self, f: &mut F) + where + F: FnMut(&mut Expression), + { + self.post_visit_expressions_return_mut(&mut move |e| { + f(e); + ControlFlow::Continue::<()>(()) + }); + } + + /// Traverses the AST and calls `f` on each Expression in post-order, + /// potentially break early and return a value. + fn post_visit_expressions_return(&self, f: &mut F) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + self.visit_expressions(f, VisitOrder::Post) + } + + /// Traverses the AST and calls `f` on each Expression in post-order. + fn post_visit_expressions(&self, f: &mut F) + where + F: FnMut(&Expression), + { + self.post_visit_expressions_return(&mut move |e| { + f(e); + ControlFlow::Continue::<()>(()) + }); + } + + fn visit_expressions(&self, f: &mut F, order: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow; + + fn visit_expressions_mut(&mut self, f: &mut F, order: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow; +} + +impl ExpressionVisitable for Expression { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + if o == VisitOrder::Pre { + f(self)?; + } + match self { + Expression::Reference(_) + | Expression::Constant(_) + | Expression::PublicReference(_) + | Expression::Number(_) + | Expression::String(_) => {} + Expression::BinaryOperation(left, _, right) => { + left.visit_expressions_mut(f, o)?; + right.visit_expressions_mut(f, o)?; + } + Expression::FreeInput(e) | Expression::UnaryOperation(_, e) => { + e.visit_expressions_mut(f, o)? + } + Expression::LambdaExpression(lambda) => lambda.visit_expressions_mut(f, o)?, + Expression::ArrayLiteral(array_literal) => array_literal.visit_expressions_mut(f, o)?, + Expression::FunctionCall(function) => function.visit_expressions_mut(f, o)?, + Expression::Tuple(items) => items + .iter_mut() + .try_for_each(|item| item.visit_expressions_mut(f, o))?, + Expression::MatchExpression(scrutinee, arms) => { + scrutinee.visit_expressions_mut(f, o)?; + arms.iter_mut() + .try_for_each(|arm| arm.visit_expressions_mut(f, o))?; + } + }; + if o == VisitOrder::Post { + f(self)?; + } + ControlFlow::Continue(()) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + if o == VisitOrder::Pre { + f(self)?; + } + match self { + Expression::Reference(_) + | Expression::Constant(_) + | Expression::PublicReference(_) + | Expression::Number(_) + | Expression::String(_) => {} + Expression::BinaryOperation(left, _, right) => { + left.visit_expressions(f, o)?; + right.visit_expressions(f, o)?; + } + Expression::FreeInput(e) | Expression::UnaryOperation(_, e) => { + e.visit_expressions(f, o)? + } + Expression::LambdaExpression(lambda) => lambda.visit_expressions(f, o)?, + Expression::ArrayLiteral(array_literal) => array_literal.visit_expressions(f, o)?, + Expression::FunctionCall(function) => function.visit_expressions(f, o)?, + Expression::Tuple(items) => items + .iter() + .try_for_each(|item| item.visit_expressions(f, o))?, + Expression::MatchExpression(scrutinee, arms) => { + scrutinee.visit_expressions(f, o)?; + arms.iter() + .try_for_each(|arm| arm.visit_expressions(f, o))?; + } + }; + if o == VisitOrder::Post { + f(self)?; + } + ControlFlow::Continue(()) + } +} + +impl ExpressionVisitable> for PilStatement { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression>) -> ControlFlow, + { + match self { + PilStatement::FunctionCall(_, _, arguments) => arguments + .iter_mut() + .try_for_each(|e| e.visit_expressions_mut(f, o)), + PilStatement::PlookupIdentity(_, left, right) + | PilStatement::PermutationIdentity(_, left, right) => [left, right] + .into_iter() + .try_for_each(|e| e.visit_expressions_mut(f, o)), + PilStatement::ConnectIdentity(_start, left, right) => left + .iter_mut() + .chain(right.iter_mut()) + .try_for_each(|e| e.visit_expressions_mut(f, o)), + + PilStatement::Namespace(_, _, e) + | PilStatement::PolynomialDefinition(_, _, e) + | PilStatement::PolynomialIdentity(_, e) + | PilStatement::PublicDeclaration(_, _, _, e) + | PilStatement::ConstantDefinition(_, _, e) + | PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions_mut(f, o), + + PilStatement::PolynomialConstantDefinition(_, _, fundef) + | PilStatement::PolynomialCommitDeclaration(_, _, Some(fundef)) => { + fundef.visit_expressions_mut(f, o) + } + PilStatement::PolynomialCommitDeclaration(_, _, None) + | PilStatement::Include(_, _) + | PilStatement::PolynomialConstantDeclaration(_, _) + | PilStatement::MacroDefinition(_, _, _, _, _) + | PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()), + } + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + match self { + PilStatement::FunctionCall(_, _, arguments) => { + arguments.iter().try_for_each(|e| e.visit_expressions(f, o)) + } + PilStatement::PlookupIdentity(_, left, right) + | PilStatement::PermutationIdentity(_, left, right) => [left, right] + .into_iter() + .try_for_each(|e| e.visit_expressions(f, o)), + PilStatement::ConnectIdentity(_start, left, right) => left + .iter() + .chain(right.iter()) + .try_for_each(|e| e.visit_expressions(f, o)), + + PilStatement::Namespace(_, _, e) + | PilStatement::PolynomialDefinition(_, _, e) + | PilStatement::PolynomialIdentity(_, e) + | PilStatement::PublicDeclaration(_, _, _, e) + | PilStatement::ConstantDefinition(_, _, e) + | PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions(f, o), + + PilStatement::PolynomialConstantDefinition(_, _, fundef) + | PilStatement::PolynomialCommitDeclaration(_, _, Some(fundef)) => { + fundef.visit_expressions(f, o) + } + PilStatement::PolynomialCommitDeclaration(_, _, None) + | PilStatement::Include(_, _) + | PilStatement::PolynomialConstantDeclaration(_, _) + | PilStatement::MacroDefinition(_, _, _, _, _) + | PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()), + } + } +} + +impl ExpressionVisitable for SelectedExpressions { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + self.selector + .as_mut() + .into_iter() + .chain(self.expressions.iter_mut()) + .try_for_each(move |item| item.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + self.selector + .as_ref() + .into_iter() + .chain(self.expressions.iter()) + .try_for_each(move |item| item.visit_expressions(f, o)) + } +} + +impl ExpressionVisitable> for FunctionDefinition { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + match self { + FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => { + e.visit_expressions_mut(f, o) + } + FunctionDefinition::Array(ae) => ae.visit_expressions_mut(f, o), + FunctionDefinition::Expression(e) => e.visit_expressions_mut(f, o), + } + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + match self { + FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => { + e.visit_expressions(f, o) + } + FunctionDefinition::Array(ae) => ae.visit_expressions(f, o), + FunctionDefinition::Expression(e) => e.visit_expressions(f, o), + } + } +} + +impl ExpressionVisitable> for ArrayExpression { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + match self { + ArrayExpression::Value(expressions) | ArrayExpression::RepeatedValue(expressions) => { + expressions + .iter_mut() + .try_for_each(|e| e.visit_expressions_mut(f, o)) + } + ArrayExpression::Concat(a1, a2) => [a1, a2] + .iter_mut() + .try_for_each(|e| e.visit_expressions_mut(f, o)), + } + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + match self { + ArrayExpression::Value(expressions) | ArrayExpression::RepeatedValue(expressions) => { + expressions + .iter() + .try_for_each(|e| e.visit_expressions(f, o)) + } + ArrayExpression::Concat(a1, a2) => { + [a1, a2].iter().try_for_each(|e| e.visit_expressions(f, o)) + } + } + } +} + +impl ExpressionVisitable for LambdaExpression { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + self.body.visit_expressions_mut(f, o) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + self.body.visit_expressions(f, o) + } +} + +impl ExpressionVisitable for ArrayLiteral { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + self.items + .iter_mut() + .try_for_each(|item| item.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + self.items + .iter() + .try_for_each(|item| item.visit_expressions(f, o)) + } +} + +impl ExpressionVisitable for FunctionCall { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + self.arguments + .iter_mut() + .try_for_each(|item| item.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + self.arguments + .iter() + .try_for_each(|item| item.visit_expressions(f, o)) + } +} + +impl ExpressionVisitable for MatchArm { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + self.pattern.visit_expressions_mut(f, o)?; + self.value.visit_expressions_mut(f, o) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + self.pattern.visit_expressions(f, o)?; + self.value.visit_expressions(f, o) + } +} + +impl ExpressionVisitable for MatchPattern { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + match self { + MatchPattern::CatchAll => ControlFlow::Continue(()), + MatchPattern::Pattern(e) => e.visit_expressions_mut(f, o), + } + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + match self { + MatchPattern::CatchAll => ControlFlow::Continue(()), + MatchPattern::Pattern(e) => e.visit_expressions(f, o), + } + } +} diff --git a/executor/src/witgen/machines/machine_extractor.rs b/executor/src/witgen/machines/machine_extractor.rs index 9652cddb0..88703a83f 100644 --- a/executor/src/witgen/machines/machine_extractor.rs +++ b/executor/src/witgen/machines/machine_extractor.rs @@ -1,5 +1,4 @@ use std::collections::HashSet; -use std::ops::ControlFlow; use super::block_machine::BlockMachine; use super::double_sorted_witness_machine::DoubleSortedWitnesses; @@ -10,10 +9,8 @@ use super::KnownMachine; use crate::witgen::{ column_map::WitnessColumnMap, generator::Generator, range_constraints::RangeConstraint, }; -use ast::analyzed::{ - util::{previsit_expressions_in_identity, previsit_expressions_in_selected_expressions}, - Expression, Identity, IdentityKind, PolyID, Reference, SelectedExpressions, -}; +use ast::analyzed::{Expression, Identity, IdentityKind, PolyID, Reference, SelectedExpressions}; +use ast::parsed::visitor::ExpressionVisitable; use itertools::Itertools; use number::FieldElement; @@ -178,9 +175,8 @@ fn all_row_connected_witnesses( /// Extracts all references to names from an identity. pub fn refs_in_identity(identity: &Identity) -> HashSet { let mut refs: HashSet = Default::default(); - previsit_expressions_in_identity(identity, &mut |expr| { + identity.pre_visit_expressions(&mut |expr| { ref_of_expression(expr).map(|id| refs.insert(id)); - ControlFlow::Continue::<()>(()) }); refs } @@ -188,9 +184,8 @@ pub fn refs_in_identity(identity: &Identity) -> HashSet { /// Extracts all references to names from selected expressions. pub fn refs_in_selected_expressions(selexpr: &SelectedExpressions) -> HashSet { let mut refs: HashSet = Default::default(); - previsit_expressions_in_selected_expressions(selexpr, &mut |expr| { + selexpr.pre_visit_expressions(&mut |expr| { ref_of_expression(expr).map(|id| refs.insert(id)); - ControlFlow::Continue::<()>(()) }); refs } diff --git a/executor/src/witgen/util.rs b/executor/src/witgen/util.rs index 0557922a9..e8849a759 100644 --- a/executor/src/witgen/util.rs +++ b/executor/src/witgen/util.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; -use ast::analyzed::util::previsit_expressions_in_identity_mut; use ast::analyzed::{Expression, Identity, PolynomialReference, Reference}; +use ast::parsed::visitor::ExpressionVisitable; /// Checks if an expression is /// - a polynomial @@ -56,11 +56,10 @@ pub fn substitute_constants( .iter() .cloned() .map(|mut identity| { - previsit_expressions_in_identity_mut(&mut identity, &mut |e| { + identity.pre_visit_expressions_mut(&mut |e| { if let Expression::Constant(name) = e { *e = Expression::Number(constants[name]) } - std::ops::ControlFlow::Continue::<()>(()) }); identity }) diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 042dfc103..08b0daee1 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -3,16 +3,14 @@ use std::fs; use std::path::{Path, PathBuf}; use analysis::MacroExpander; -use ast::parsed::utils::postvisit_expression_mut; + +use ast::parsed::visitor::ExpressionVisitable; use ast::parsed::{ self, ArrayExpression, ArrayLiteral, BinaryOperator, FunctionDefinition, LambdaExpression, MatchArm, MatchPattern, PilStatement, PolynomialName, UnaryOperator, }; use number::{DegreeType, FieldElement}; -use ast::analyzed::util::{ - postvisit_expressions_in_identity_mut, previsit_expressions_in_pil_file_mut, -}; use ast::analyzed::{ Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, PolynomialReference, PolynomialType, PublicDeclaration, Reference, RepeatedArray, SelectedExpressions, SourceRef, @@ -83,11 +81,10 @@ impl From> for Analyzed { .unwrap_or_else(|| panic!("Column {} not found.", reference.name)); reference.poly_id = Some(poly.into()); }; - previsit_expressions_in_pil_file_mut(&mut result, &mut |e| { + result.pre_visit_expressions_mut(&mut |e| { if let Expression::Reference(Reference::Poly(reference)) = e { assign_id(reference); } - std::ops::ControlFlow::Continue::<()>(()) }); result .public_declarations @@ -724,7 +721,7 @@ fn substitute_intermediate( identities .into_iter() .scan(HashMap::default(), |cache, mut identity| { - postvisit_expressions_in_identity_mut(&mut identity, &mut |e| { + identity.post_visit_expressions_mut(&mut |e| { if let Expression::Reference(Reference::Poly(r)) = e { let poly_id = r.poly_id.unwrap(); match poly_id.ptype { @@ -740,7 +737,6 @@ fn substitute_intermediate( } } } - std::ops::ControlFlow::Continue::<()>(()) }); Some(identity) }) @@ -755,7 +751,7 @@ fn inlined_expression_from_intermediate_poly_id( cache: &mut HashMap>, ) -> Expression { let mut expr = intermediate_polynomials[&poly_id].clone(); - postvisit_expression_mut(&mut expr, &mut |e| { + expr.post_visit_expressions_mut(&mut |e| { if let Expression::Reference(Reference::Poly(r)) = e { let poly_id = r.poly_id.unwrap(); match poly_id.ptype { @@ -773,7 +769,6 @@ fn inlined_expression_from_intermediate_poly_id( } } } - std::ops::ControlFlow::Continue::<()>(()) }); cache.insert(poly_id, expr.clone()); expr diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 6a1be626c..27ac040c7 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -1,18 +1,14 @@ //! PIL-based optimizer use std::collections::{BTreeMap, HashSet}; -use std::ops::ControlFlow; -use ast::analyzed::util::{ - postvisit_expressions_in_pil_file_mut, previsit_expressions_in_pil_file_mut, -}; use ast::analyzed::Reference; use ast::analyzed::{ build::{build_mul, build_number, build_sub}, Analyzed, BinaryOperator, Expression, FunctionValueDefinition, IdentityKind, PolyID, PolynomialReference, }; -use ast::parsed::utils::postvisit_expression_mut; +use ast::parsed::visitor::ExpressionVisitable; use ast::parsed::UnaryOperator; use number::FieldElement; @@ -54,7 +50,7 @@ fn remove_constant_fixed_columns(pil_file: &mut Analyzed) { }) .collect::>(); - previsit_expressions_in_pil_file_mut(pil_file, &mut |e| { + pil_file.pre_visit_expressions_mut(&mut |e| { if let Expression::Reference(Reference::Poly(PolynomialReference { name: _, index, @@ -67,7 +63,6 @@ fn remove_constant_fixed_columns(pil_file: &mut Analyzed) { *e = Expression::Number(*value); } } - ControlFlow::Continue::<()>(()) }); pil_file.remove_polynomials(&constant_polys.keys().cloned().collect()); @@ -103,17 +98,11 @@ fn constant_value(function: &FunctionValueDefinition) -> Opt /// Simplifies multiplications by zero and one. fn simplify_expressions(pil_file: &mut Analyzed) { - postvisit_expressions_in_pil_file_mut(pil_file, &mut |e| -> ControlFlow<()> { - simplify_expression_single(e); - ControlFlow::Continue(()) - }); + pil_file.post_visit_expressions_mut(&mut simplify_expression_single); } fn simplify_expression(mut e: Expression) -> Expression { - postvisit_expression_mut(&mut e, &mut |e| -> ControlFlow<()> { - simplify_expression_single(e); - ControlFlow::Continue(()) - }); + e.post_visit_expressions_mut(&mut simplify_expression_single); e } @@ -281,7 +270,7 @@ fn remove_constant_witness_columns(pil_file: &mut Analyzed) .filter_map(|expr| constrained_to_constant(expr)) .collect::>(); - previsit_expressions_in_pil_file_mut(pil_file, &mut |e| { + pil_file.pre_visit_expressions_mut(&mut |e| { if let Expression::Reference(Reference::Poly(PolynomialReference { name: _, index, @@ -294,7 +283,6 @@ fn remove_constant_witness_columns(pil_file: &mut Analyzed) *e = Expression::Number(*value); } } - ControlFlow::Continue::<()>(()) }); pil_file.remove_polynomials(&constant_polys.keys().cloned().collect());