From 61e2894206caa2746ad70adcd6d958602be4aa4b Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 10 Oct 2023 12:03:50 +0200 Subject: [PATCH] Split expression visitor for Analyzed. --- ast/src/analyzed/mod.rs | 34 ++++++++++++++++++++-- ast/src/analyzed/visitor.rs | 48 -------------------------------- pil_analyzer/src/pil_analyzer.rs | 7 +++-- pilopt/src/lib.rs | 27 ++++++++++++------ 4 files changed, 55 insertions(+), 61 deletions(-) diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 0b4e37c0e..d79edb914 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -163,14 +163,16 @@ impl Analyzed { poly.id = replacements[&poly_id].id; } }); - self.pre_visit_expressions_mut(&mut |expr| { + let visitor = &mut |expr: &mut Expression<_>| { if let Expression::Reference(Reference::Poly(poly)) = expr { poly.poly_id = poly.poly_id.map(|poly_id| { assert!(!to_remove.contains(&poly_id)); replacements[&poly_id] }); } - }); + }; + self.post_visit_expressions_in_definitions_mut(visitor); + self.post_visit_expressions_in_identities_mut(visitor); } /// Adds a polynomial identity and returns the ID. @@ -225,6 +227,34 @@ impl Analyzed { retain }) } + + pub fn post_visit_expressions_in_identities_mut(&mut self, f: &mut F) + where + F: FnMut(&mut Expression), + { + self.identities + .iter_mut() + .for_each(|i| i.post_visit_expressions_mut(f)); + } + + pub fn post_visit_expressions_in_definitions_mut(&mut self, f: &mut F) + where + F: FnMut(&mut Expression), + { + // TODO add public inputs if we change them to expressions at some point. + self.definitions + .values_mut() + .for_each(|(_poly, definition)| match definition { + Some(FunctionValueDefinition::Mapping(e)) + | Some(FunctionValueDefinition::Query(e)) => e.post_visit_expressions_mut(f), + Some(FunctionValueDefinition::Array(elements)) => elements + .iter_mut() + .flat_map(|e| e.pattern.iter_mut()) + .for_each(|e| e.post_visit_expressions_mut(f)), + Some(FunctionValueDefinition::Expression(e)) => e.post_visit_expressions_mut(f), + None => {} + }); + } } #[derive(Debug, Clone)] diff --git a/ast/src/analyzed/visitor.rs b/ast/src/analyzed/visitor.rs index 11ea2d3d3..6ebffe02c 100644 --- a/ast/src/analyzed/visitor.rs +++ b/ast/src/analyzed/visitor.rs @@ -2,54 +2,6 @@ use crate::parsed::visitor::VisitOrder; use super::*; -impl ExpressionVisitable> for Analyzed { - fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow - where - F: FnMut(&mut parsed::Expression) -> ControlFlow, - { - // TODO add constants if we change them to expressions at some point. - self.definitions - .values_mut() - .try_for_each(|(_poly, definition)| match definition { - Some(FunctionValueDefinition::Mapping(e)) - | Some(FunctionValueDefinition::Query(e)) => e.visit_expressions_mut(f, o), - Some(FunctionValueDefinition::Array(elements)) => elements - .iter_mut() - .flat_map(|e| e.pattern.iter_mut()) - .try_for_each(|e| e.visit_expressions_mut(f, o)), - Some(FunctionValueDefinition::Expression(e)) => e.visit_expressions_mut(f, o), - None => ControlFlow::Continue(()), - })?; - - self.identities - .iter_mut() - .try_for_each(|i| i.visit_expressions_mut(f, o)) - } - - fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow - where - F: FnMut(&parsed::Expression) -> ControlFlow, - { - // TODO add constants if we change them to expressions at some point. - self.definitions - .values() - .try_for_each(|(_poly, definition)| match definition { - Some(FunctionValueDefinition::Mapping(e)) - | Some(FunctionValueDefinition::Query(e)) => e.visit_expressions(f, o), - Some(FunctionValueDefinition::Array(elements)) => elements - .iter() - .flat_map(|e| e.pattern.iter()) - .try_for_each(|e| e.visit_expressions(f, o)), - Some(FunctionValueDefinition::Expression(e)) => e.visit_expressions(f, o), - None => ControlFlow::Continue(()), - })?; - - self.identities - .iter() - .try_for_each(|i| i.visit_expressions(f, o)) - } -} - impl> ExpressionVisitable for Identity { fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow where diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 0432f1483..11f84f180 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -77,11 +77,14 @@ impl From> for Analyzed { reference.poly_id = Some(poly.into()); } }; - result.pre_visit_expressions_mut(&mut |e| { + let expr_visitor = &mut |e: &mut Expression<_>| { if let Expression::Reference(Reference::Poly(reference)) = e { assign_id(reference); } - }); + }; + result.post_visit_expressions_in_definitions_mut(expr_visitor); + result.post_visit_expressions_in_identities_mut(expr_visitor); + // TODO at some point, merge public declarations with definitions as well. result .public_declarations .values_mut() diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index b99acb82f..553e141e1 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -47,7 +47,7 @@ pub fn optimize_constants(mut pil_file: Analyzed) -> Analyze /// Inlines references to symbols with a single constant value. fn inline_constant_values(pil_file: &mut Analyzed) { let constants = compute_constants(pil_file); - pil_file.post_visit_expressions_mut(&mut |e| { + let visitor = &mut |e: &mut Expression<_>| { if let Expression::Reference(Reference::Poly(poly)) = e { if !poly.next && poly.index.is_none() { if let Some(value) = constants.get(&poly.name) { @@ -55,12 +55,14 @@ fn inline_constant_values(pil_file: &mut Analyzed) { } } } - }); + }; + pil_file.post_visit_expressions_in_definitions_mut(visitor); + pil_file.post_visit_expressions_in_identities_mut(visitor); } /// Substitutes expression that evaluate to a constant value. fn evaluate_constant_subtrees(pil_file: &mut Analyzed) { - pil_file.post_visit_expressions_mut(&mut |e| match e { + let visitor = &mut |e: &mut Expression<_>| match e { Expression::BinaryOperation(left, op, right) => { if let (Expression::Number(l), Expression::Number(r)) = (left.as_ref(), right.as_ref()) { @@ -73,7 +75,9 @@ fn evaluate_constant_subtrees(pil_file: &mut Analyzed) { } } _ => {} - }); + }; + pil_file.post_visit_expressions_in_definitions_mut(visitor); + pil_file.post_visit_expressions_in_identities_mut(visitor); } /// Identifies fixed columns that only have a single value, replaces every @@ -95,7 +99,7 @@ fn remove_constant_fixed_columns(pil_file: &mut Analyzed) { }) .collect::>(); - pil_file.pre_visit_expressions_mut(&mut |e| { + let visitor = &mut |e: &mut Expression<_>| { if let Expression::Reference(Reference::Poly(PolynomialReference { name: _, index, @@ -108,7 +112,9 @@ fn remove_constant_fixed_columns(pil_file: &mut Analyzed) { *e = Expression::Number(*value); } } - }); + }; + pil_file.post_visit_expressions_in_definitions_mut(visitor); + pil_file.post_visit_expressions_in_identities_mut(visitor); pil_file.remove_polynomials(&constant_polys.keys().cloned().collect()); } @@ -143,7 +149,8 @@ fn constant_value(function: &FunctionValueDefinition) -> Opt /// Simplifies multiplications by zero and one. fn simplify_expressions(pil_file: &mut Analyzed) { - pil_file.post_visit_expressions_mut(&mut simplify_expression_single); + pil_file.post_visit_expressions_in_definitions_mut(&mut simplify_expression_single); + pil_file.post_visit_expressions_in_identities_mut(&mut simplify_expression_single); } fn simplify_expression(mut e: Expression) -> Expression { @@ -315,7 +322,7 @@ fn remove_constant_witness_columns(pil_file: &mut Analyzed) .filter_map(|expr| constrained_to_constant(expr)) .collect::>(); - pil_file.pre_visit_expressions_mut(&mut |e| { + let visitor = &mut |e: &mut Expression<_>| { if let Expression::Reference(Reference::Poly(PolynomialReference { name: _, index, @@ -328,7 +335,9 @@ fn remove_constant_witness_columns(pil_file: &mut Analyzed) *e = Expression::Number(*value); } } - }); + }; + pil_file.post_visit_expressions_in_definitions_mut(visitor); + pil_file.post_visit_expressions_in_identities_mut(visitor); pil_file.remove_polynomials(&constant_polys.keys().cloned().collect()); }