diff --git a/pil_analyzer/src/expression_processor.rs b/pil_analyzer/src/expression_processor.rs new file mode 100644 index 000000000..089add6cd --- /dev/null +++ b/pil_analyzer/src/expression_processor.rs @@ -0,0 +1,181 @@ +use std::collections::HashMap; + +use ast::{ + analyzed::{Expression, PolynomialReference, Reference, RepeatedArray}, + parsed::{ + self, ArrayExpression, ArrayLiteral, LambdaExpression, MatchArm, MatchPattern, + NamespacedPolynomialReference, SelectedExpressions, + }, +}; +use number::DegreeType; + +/// The ExpressionProcessor turns parsed expressions into analyzed expressions. +/// Its main job is to resolve references: +/// It turns simple references into fully namespaced references and resolves local function variables. +/// It also evaluates expressions that are required to be compile-time constant. +pub struct ExpressionProcessor { + resolver: R, + local_variables: HashMap, +} + +pub trait ReferenceResolver { + /// Turns a reference to a name with an optional namespace to an absolute name. + fn resolve(&self, namespace: &Option, name: &str) -> String; +} + +impl ExpressionProcessor { + pub fn new(resolver: R) -> Self { + Self { + resolver, + local_variables: Default::default(), + } + } + + pub fn process_selected_expression( + &mut self, + expr: SelectedExpressions>, + ) -> SelectedExpressions> { + SelectedExpressions { + selector: expr.selector.map(|e| self.process_expression(e)), + expressions: self.process_expressions(expr.expressions), + } + } + + pub fn process_array_expression( + &mut self, + array_expression: ::ast::parsed::ArrayExpression, + size: DegreeType, + ) -> Vec> { + match array_expression { + ArrayExpression::Value(expressions) => { + let values = self.process_expressions(expressions); + let size = values.len() as DegreeType; + vec![RepeatedArray::new(values, size)] + } + ArrayExpression::RepeatedValue(expressions) => { + if size == 0 { + vec![] + } else { + vec![RepeatedArray::new( + self.process_expressions(expressions), + size, + )] + } + } + ArrayExpression::Concat(left, right) => self + .process_array_expression(*left, size) + .into_iter() + .chain(self.process_array_expression(*right, size)) + .collect(), + } + } + + pub fn process_expressions( + &mut self, + exprs: Vec>, + ) -> Vec> { + exprs + .into_iter() + .map(|e| self.process_expression(e)) + .collect() + } + + pub fn process_expression(&mut self, expr: parsed::Expression) -> Expression { + use parsed::Expression as PExpression; + match expr { + PExpression::Reference(poly) => Expression::Reference(self.process_reference(poly)), + PExpression::PublicReference(name) => Expression::PublicReference(name), + PExpression::Number(n) => Expression::Number(n), + PExpression::String(value) => Expression::String(value), + PExpression::Tuple(items) => Expression::Tuple(self.process_expressions(items)), + PExpression::ArrayLiteral(ArrayLiteral { items }) => { + Expression::ArrayLiteral(ArrayLiteral { + items: self.process_expressions(items), + }) + } + PExpression::LambdaExpression(LambdaExpression { params, body }) => { + let body = Box::new(self.process_function(¶ms, *body)); + Expression::LambdaExpression(LambdaExpression { params, body }) + } + PExpression::BinaryOperation(left, op, right) => Expression::BinaryOperation( + Box::new(self.process_expression(*left)), + op, + Box::new(self.process_expression(*right)), + ), + PExpression::UnaryOperation(op, value) => { + Expression::UnaryOperation(op, Box::new(self.process_expression(*value))) + } + PExpression::IndexAccess(index_access) => { + Expression::IndexAccess(parsed::IndexAccess { + array: Box::new(self.process_expression(*index_access.array)), + index: Box::new(self.process_expression(*index_access.index)), + }) + } + PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall { + function: Box::new(self.process_expression(*c.function)), + arguments: self.process_expressions(c.arguments), + }), + PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( + Box::new(self.process_expression(*scrutinee)), + arms.into_iter() + .map(|MatchArm { pattern, value }| MatchArm { + pattern: match pattern { + MatchPattern::CatchAll => MatchPattern::CatchAll, + MatchPattern::Pattern(e) => { + MatchPattern::Pattern(self.process_expression(e)) + } + }, + value: self.process_expression(value), + }) + .collect(), + ), + PExpression::FreeInput(_) => panic!(), + } + } + + fn process_reference(&mut self, reference: NamespacedPolynomialReference) -> Reference { + if reference.namespace.is_none() && self.local_variables.contains_key(&reference.name) { + let id = self.local_variables[&reference.name]; + Reference::LocalVar(id, reference.name.to_string()) + } else { + Reference::Poly(self.process_namespaced_polynomial_reference(reference)) + } + } + + pub fn process_function( + &mut self, + params: &[String], + expression: ::ast::parsed::Expression, + ) -> Expression { + let previous_local_vars = std::mem::take(&mut self.local_variables); + + assert!(self.local_variables.is_empty()); + self.local_variables = params + .iter() + .enumerate() + .map(|(i, p)| (p.clone(), i as u64)) + .collect(); + // Re-add the outer local variables if we do not overwrite them + // and increase their index by the number of parameters. + // TODO re-evaluate if this mechanism makes sense as soon as we properly + // support nested functions and closures. + for (name, index) in &previous_local_vars { + self.local_variables + .entry(name.clone()) + .or_insert(index + params.len() as u64); + } + let processed_value = self.process_expression(expression); + self.local_variables = previous_local_vars; + processed_value + } + + pub fn process_namespaced_polynomial_reference( + &mut self, + poly: ::ast::parsed::NamespacedPolynomialReference, + ) -> PolynomialReference { + PolynomialReference { + name: self.resolver.resolve(&poly.namespace, &poly.name), + poly_id: None, + } + } +} diff --git a/pil_analyzer/src/lib.rs b/pil_analyzer/src/lib.rs index 9ce37c432..1d654f89e 100644 --- a/pil_analyzer/src/lib.rs +++ b/pil_analyzer/src/lib.rs @@ -2,6 +2,7 @@ mod condenser; pub mod evaluator; +pub mod expression_processor; pub mod pil_analyzer; use std::path::Path; diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 402e37196..1dd59684a 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -6,19 +6,18 @@ use analysis::MacroExpander; use ast::parsed::visitor::ExpressionVisitable; use ast::parsed::{ - self, ArrayExpression, ArrayLiteral, FunctionDefinition, LambdaExpression, MatchArm, - MatchPattern, NamespacedPolynomialReference, PilStatement, PolynomialName, SelectedExpressions, + self, FunctionDefinition, LambdaExpression, PilStatement, PolynomialName, SelectedExpressions, }; use number::{DegreeType, FieldElement}; use ast::analyzed::{ AlgebraicExpression, Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, - PolynomialReference, PolynomialType, PublicDeclaration, Reference, RepeatedArray, SourceRef, - StatementIdentifier, Symbol, SymbolKind, + PolynomialType, PublicDeclaration, SourceRef, StatementIdentifier, Symbol, SymbolKind, }; use crate::evaluator::EvalError; -use crate::{condenser, evaluator}; +use crate::expression_processor::ReferenceResolver; +use crate::{condenser, evaluator, expression_processor::ExpressionProcessor}; pub fn process_pil_file(path: &Path) -> Analyzed { let mut analyzer = PILAnalyzer::new(); @@ -258,25 +257,28 @@ impl PILAnalyzer { PilStatement::PlookupIdentity(start, key, haystack) => ( start, IdentityKind::Plookup, - ExpressionProcessor::new(self).process_selected_expression(key), - ExpressionProcessor::new(self).process_selected_expression(haystack), + self.expression_processor().process_selected_expression(key), + self.expression_processor() + .process_selected_expression(haystack), ), PilStatement::PermutationIdentity(start, left, right) => ( start, IdentityKind::Permutation, - ExpressionProcessor::new(self).process_selected_expression(left), - ExpressionProcessor::new(self).process_selected_expression(right), + self.expression_processor() + .process_selected_expression(left), + self.expression_processor() + .process_selected_expression(right), ), PilStatement::ConnectIdentity(start, left, right) => ( start, IdentityKind::Connect, SelectedExpressions { selector: None, - expressions: ExpressionProcessor::new(self).process_expressions(left), + expressions: self.expression_processor().process_expressions(left), }, SelectedExpressions { selector: None, - expressions: ExpressionProcessor::new(self).process_expressions(right), + expressions: self.expression_processor().process_expressions(right), }, ), // TODO at some point, these should all be caught by the type checker. @@ -376,7 +378,7 @@ impl PILAnalyzer { FunctionDefinition::Query(params, expr) => { assert!(!have_array_size); assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed)); - let body = Box::new(ExpressionProcessor::new(self).process_function(¶ms, expr)); + let body = Box::new(self.expression_processor().process_function(¶ms, expr)); FunctionValueDefinition::Query(Expression::LambdaExpression(LambdaExpression { params, body, @@ -384,8 +386,9 @@ impl PILAnalyzer { } FunctionDefinition::Array(value) => { let size = value.solve(self.polynomial_degree.unwrap()); - let expression = - ExpressionProcessor::new(self).process_array_expression(value, size); + let expression = self + .expression_processor() + .process_array_expression(value, size); assert_eq!( expression.iter().map(|e| e.size()).sum::(), self.polynomial_degree.unwrap() @@ -412,8 +415,9 @@ impl PILAnalyzer { index: parsed::Expression, ) { let id = self.public_declarations.len() as u64; - let polynomial = - ExpressionProcessor::new(self).process_namespaced_polynomial_reference(poly); + let polynomial = self + .expression_processor() + .process_namespaced_polynomial_reference(poly); let array_index = array_index.map(|i| { let index = self.evaluate_expression(i).unwrap().to_degree(); assert!(index <= usize::MAX as u64); @@ -445,186 +449,30 @@ impl PILAnalyzer { format!("{}.{name}", self.namespace) } - pub fn namespaced_ref_to_absolute(&self, namespace: &Option, name: &str) -> String { - if name.starts_with('%') || self.definitions.contains_key(&name.to_string()) { - assert!(namespace.is_none()); - // Constants are not namespaced - name.to_string() - } else { - format!("{}.{name}", namespace.as_ref().unwrap_or(&self.namespace)) - } - } - fn evaluate_expression(&self, expr: ::ast::parsed::Expression) -> Result { evaluator::evaluate_expression(&self.process_expression(expr), &self.definitions)? .try_to_number() } + fn expression_processor(&self) -> ExpressionProcessor> { + ExpressionProcessor::new(PILResolver(self)) + } + fn process_expression(&self, expr: ::ast::parsed::Expression) -> Expression { - ExpressionProcessor::new(self).process_expression(expr) + self.expression_processor().process_expression(expr) } } -/// The ExpressionProcessor turns parsed expressions into analyzed expressions. -/// Its main job is to resolve references: -/// It turns simple references into fully namespaced references and resolves local function variables. -/// It also evaluates expressions that are required to be compile-time constant. -struct ExpressionProcessor<'a, T> { - analyzer: &'a PILAnalyzer, - local_variables: HashMap, -} +struct PILResolver<'a, T>(&'a PILAnalyzer); -impl<'a, T: FieldElement> ExpressionProcessor<'a, T> { - fn new(analyzer: &'a PILAnalyzer) -> Self { - Self { - analyzer, - local_variables: Default::default(), - } - } - - pub fn process_selected_expression( - &mut self, - expr: SelectedExpressions>, - ) -> SelectedExpressions> { - SelectedExpressions { - selector: expr.selector.map(|e| self.process_expression(e)), - expressions: self.process_expressions(expr.expressions), - } - } - - pub fn process_array_expression( - &mut self, - array_expression: ::ast::parsed::ArrayExpression, - size: DegreeType, - ) -> Vec> { - match array_expression { - ArrayExpression::Value(expressions) => { - let values = self.process_expressions(expressions); - let size = values.len() as DegreeType; - vec![RepeatedArray::new(values, size)] - } - ArrayExpression::RepeatedValue(expressions) => { - if size == 0 { - vec![] - } else { - vec![RepeatedArray::new( - self.process_expressions(expressions), - size, - )] - } - } - ArrayExpression::Concat(left, right) => self - .process_array_expression(*left, size) - .into_iter() - .chain(self.process_array_expression(*right, size)) - .collect(), - } - } - - pub fn process_expressions(&mut self, exprs: Vec>) -> Vec> { - exprs - .into_iter() - .map(|e| self.process_expression(e)) - .collect() - } - - pub fn process_expression(&mut self, expr: parsed::Expression) -> Expression { - use parsed::Expression as PExpression; - match expr { - PExpression::Reference(poly) => Expression::Reference(self.process_reference(poly)), - PExpression::PublicReference(name) => Expression::PublicReference(name), - PExpression::Number(n) => Expression::Number(n), - PExpression::String(value) => Expression::String(value), - PExpression::Tuple(items) => Expression::Tuple(self.process_expressions(items)), - PExpression::ArrayLiteral(ArrayLiteral { items }) => { - Expression::ArrayLiteral(ArrayLiteral { - items: self.process_expressions(items), - }) - } - PExpression::LambdaExpression(LambdaExpression { params, body }) => { - let body = Box::new(self.process_function(¶ms, *body)); - Expression::LambdaExpression(LambdaExpression { params, body }) - } - PExpression::BinaryOperation(left, op, right) => Expression::BinaryOperation( - Box::new(self.process_expression(*left)), - op, - Box::new(self.process_expression(*right)), - ), - PExpression::UnaryOperation(op, value) => { - Expression::UnaryOperation(op, Box::new(self.process_expression(*value))) - } - PExpression::IndexAccess(index_access) => { - Expression::IndexAccess(parsed::IndexAccess { - array: Box::new(self.process_expression(*index_access.array)), - index: Box::new(self.process_expression(*index_access.index)), - }) - } - PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall { - function: Box::new(self.process_expression(*c.function)), - arguments: self.process_expressions(c.arguments), - }), - PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( - Box::new(self.process_expression(*scrutinee)), - arms.into_iter() - .map(|MatchArm { pattern, value }| MatchArm { - pattern: match pattern { - MatchPattern::CatchAll => MatchPattern::CatchAll, - MatchPattern::Pattern(e) => { - MatchPattern::Pattern(self.process_expression(e)) - } - }, - value: self.process_expression(value), - }) - .collect(), - ), - PExpression::FreeInput(_) => panic!(), - } - } - - fn process_reference(&mut self, reference: NamespacedPolynomialReference) -> Reference { - if reference.namespace.is_none() && self.local_variables.contains_key(&reference.name) { - let id = self.local_variables[&reference.name]; - Reference::LocalVar(id, reference.name.to_string()) +impl<'a, T: FieldElement> ReferenceResolver for PILResolver<'a, T> { + fn resolve(&self, namespace: &Option, name: &str) -> String { + if name.starts_with('%') || self.0.definitions.contains_key(&name.to_string()) { + assert!(namespace.is_none()); + // Constants are not namespaced + name.to_string() } else { - Reference::Poly(self.process_namespaced_polynomial_reference(reference)) - } - } - - fn process_function( - &mut self, - params: &[String], - expression: ::ast::parsed::Expression, - ) -> Expression { - let previous_local_vars = std::mem::take(&mut self.local_variables); - - assert!(self.local_variables.is_empty()); - self.local_variables = params - .iter() - .enumerate() - .map(|(i, p)| (p.clone(), i as u64)) - .collect(); - // Re-add the outer local variables if we do not overwrite them - // and increase their index by the number of parameters. - for (name, index) in &previous_local_vars { - self.local_variables - .entry(name.clone()) - .or_insert(index + params.len() as u64); - } - let processed_value = self.process_expression(expression); - self.local_variables = previous_local_vars; - processed_value - } - - pub fn process_namespaced_polynomial_reference( - &mut self, - poly: ::ast::parsed::NamespacedPolynomialReference, - ) -> PolynomialReference { - let name = self - .analyzer - .namespaced_ref_to_absolute(&poly.namespace, &poly.name); - PolynomialReference { - name, - poly_id: None, + format!("{}.{name}", namespace.as_ref().unwrap_or(&self.0.namespace)) } } }