From a9585d2adde2e9c7499ffac6df9677c1f9af7875 Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 30 Sep 2023 00:11:53 +0200 Subject: [PATCH] Extract evaluator and processor. --- executor/src/constant_evaluator/mod.rs | 67 ++------ pil_analyzer/src/evaluator.rs | 94 +++++++++++ pil_analyzer/src/lib.rs | 1 + pil_analyzer/src/pil_analyzer.rs | 210 +++++++++++-------------- 4 files changed, 201 insertions(+), 171 deletions(-) create mode 100644 pil_analyzer/src/evaluator.rs diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index ee3700aeb..bfeab5e13 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -1,10 +1,9 @@ use std::collections::HashMap; -use ast::analyzed::{Analyzed, Expression, FunctionValueDefinition, Reference}; -use ast::parsed::{FunctionCall, MatchArm, MatchPattern}; -use ast::{evaluate_binary_operation, evaluate_unary_operation}; +use ast::analyzed::{Analyzed, FunctionValueDefinition}; use itertools::Itertools; use number::{DegreeType, FieldElement}; +use pil_analyzer::evaluator::Evaluator; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; /// Generates the constant polynomial values for all constant polynomials @@ -42,18 +41,21 @@ fn generate_values( .into_par_iter() .map(|i| { Evaluator { - analyzed, + constants: &analyzed.constants, + definitions: &analyzed.definitions, variables: &[i.into()], - other_constants, + function_cache: other_constants, } .evaluate(body) + .unwrap() }) .collect(), FunctionValueDefinition::Array(values) => { let evaluator = Evaluator { - analyzed, + constants: &analyzed.constants, + definitions: &analyzed.definitions, variables: &[], - other_constants, + function_cache: other_constants, }; let values: Vec<_> = values .iter() @@ -61,7 +63,7 @@ fn generate_values( let items = elements .pattern() .iter() - .map(|v| evaluator.evaluate(v)) + .map(|v| evaluator.evaluate(v).unwrap()) .collect::>(); items @@ -81,55 +83,6 @@ fn generate_values( } } -struct Evaluator<'a, T> { - analyzed: &'a Analyzed, - other_constants: &'a HashMap<&'a str, Vec>, - variables: &'a [T], -} - -impl<'a, T: FieldElement> Evaluator<'a, T> { - fn evaluate(&self, expr: &Expression) -> T { - match expr { - Expression::Constant(name) => self.analyzed.constants[name], - Expression::Reference(Reference::LocalVar(i, _name)) => self.variables[*i as usize], - Expression::Reference(Reference::Poly(_)) => todo!(), - Expression::PublicReference(_) => todo!(), - Expression::Number(n) => *n, - Expression::String(_) => panic!(), - Expression::Tuple(_) => panic!(), - Expression::ArrayLiteral(_) => panic!(), - Expression::BinaryOperation(left, op, right) => { - evaluate_binary_operation(self.evaluate(left), *op, self.evaluate(right)) - } - Expression::UnaryOperation(op, expr) => { - evaluate_unary_operation(*op, self.evaluate(expr)) - } - Expression::LambdaExpression(_) => panic!(), - Expression::FunctionCall(FunctionCall { id, arguments }) => { - let arg_values = arguments - .iter() - .map(|a| self.evaluate(a)) - .collect::>(); - assert!(arg_values.len() == 1); - let values = &self.other_constants[id.as_str()]; - values[arg_values[0].to_degree() as usize % values.len()] - } - Expression::MatchExpression(scrutinee, arms) => { - let v = self.evaluate(scrutinee); - arms.iter() - .find_map(|MatchArm { pattern, value }| match pattern { - MatchPattern::Pattern(p) => { - (self.evaluate(p) == v).then(|| self.evaluate(value)) - } - MatchPattern::CatchAll => Some(self.evaluate(value)), - }) - .expect("No arm matched the value {v}") - } - Expression::FreeInput(_) => panic!(), - } - } -} - #[cfg(test)] mod test { use number::GoldilocksField; diff --git a/pil_analyzer/src/evaluator.rs b/pil_analyzer/src/evaluator.rs new file mode 100644 index 000000000..357315f1a --- /dev/null +++ b/pil_analyzer/src/evaluator.rs @@ -0,0 +1,94 @@ +use std::collections::HashMap; + +use ast::{ + analyzed::{Analyzed, Expression, FunctionValueDefinition, Reference, Symbol}, + evaluate_binary_operation, evaluate_unary_operation, + parsed::{FunctionCall, MatchArm, MatchPattern}, +}; +use number::FieldElement; + +/// Evaluates an expression to a single value. +pub fn evaluate_expression( + analyzed: &Analyzed, + expression: &Expression, +) -> Result { + Evaluator { + constants: &analyzed.constants, + definitions: &analyzed.definitions, + function_cache: &Default::default(), + variables: &[], + } + .evaluate(expression) +} + +pub struct Evaluator<'a, T> { + pub constants: &'a HashMap, + pub definitions: &'a HashMap>)>, + /// Contains full value tables of functions (columns) we already evaluated. + pub function_cache: &'a HashMap<&'a str, Vec>, + pub variables: &'a [T], +} + +impl<'a, T: FieldElement> Evaluator<'a, T> { + pub fn evaluate(&self, expr: &Expression) -> Result { + match expr { + Expression::Constant(name) => Ok(self.constants[name]), + Expression::Reference(Reference::LocalVar(i, _name)) => Ok(self.variables[*i as usize]), + Expression::Reference(Reference::Poly(poly)) => { + if !poly.next && poly.index.is_none() { + let name = poly.name.to_owned(); + if let Some(value) = self.constants.get(&name) { + Ok(*value) + } else { + let (_, value) = &self.definitions[&name]; + match value { + Some(FunctionValueDefinition::Expression(value)) => { + self.evaluate(value) + } + _ => Err("Cannot evaluate function values".to_string()), + } + } + } else { + Err("Cannot evaluate arrays or next references.".to_string()) + } + } + Expression::PublicReference(r) => Err(format!("Cannot evaluate public reference: {r}")), + Expression::Number(n) => Ok(*n), + Expression::String(_) => Err("Cannot evaluate string literal.".to_string()), + Expression::Tuple(_) => Err("Cannot evaluate tuple.".to_string()), + Expression::ArrayLiteral(_) => Err("Cannot evaluate array literal.".to_string()), + Expression::BinaryOperation(left, op, right) => Ok(evaluate_binary_operation( + self.evaluate(left)?, + *op, + self.evaluate(right)?, + )), + Expression::UnaryOperation(op, expr) => { + Ok(evaluate_unary_operation(*op, self.evaluate(expr)?)) + } + Expression::LambdaExpression(_) => { + Err("Cannot evaluate lambda expression.".to_string()) + } + Expression::FunctionCall(FunctionCall { id, arguments }) => { + let arg_values = arguments + .iter() + .map(|a| self.evaluate(a)) + .collect::, _>>()?; + assert!(arg_values.len() == 1); + let values = &self.function_cache[id.as_str()]; + Ok(values[arg_values[0].to_degree() as usize % values.len()]) + } + Expression::MatchExpression(scrutinee, arms) => { + let v = self.evaluate(scrutinee); + arms.iter() + .find_map(|MatchArm { pattern, value }| match pattern { + MatchPattern::Pattern(p) => { + (self.evaluate(p) == v).then(|| self.evaluate(value)) + } + MatchPattern::CatchAll => Some(self.evaluate(value)), + }) + .expect("No arm matched the value {v}") + } + Expression::FreeInput(_) => Err("Cannot evaluate free input.".to_string()), + } + } +} diff --git a/pil_analyzer/src/lib.rs b/pil_analyzer/src/lib.rs index 92314ef71..5793fb6ad 100644 --- a/pil_analyzer/src/lib.rs +++ b/pil_analyzer/src/lib.rs @@ -1,5 +1,6 @@ #![deny(clippy::print_stdout)] +pub mod evaluator; pub mod pil_analyzer; use std::path::Path; diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 4ba23a852..6763ea2ef 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -6,8 +6,8 @@ use analysis::MacroExpander; use ast::parsed::visitor::ExpressionVisitable; use ast::parsed::{ - self, ArrayExpression, ArrayLiteral, BinaryOperator, FunctionDefinition, LambdaExpression, - MatchArm, MatchPattern, PilStatement, PolynomialName, UnaryOperator, + self, ArrayExpression, ArrayLiteral, FunctionDefinition, LambdaExpression, MatchArm, + MatchPattern, PilStatement, PolynomialName, }; use number::{DegreeType, FieldElement}; @@ -17,6 +17,8 @@ use ast::analyzed::{ StatementIdentifier, Symbol, SymbolKind, }; +use crate::evaluator::Evaluator; + pub fn process_pil_file(path: &Path) -> Analyzed { let mut analyzer = PILAnalyzer::new(); analyzer.process_file(path); @@ -46,7 +48,6 @@ struct PILAnalyzer { current_file: PathBuf, symbol_counters: BTreeMap, identity_counter: HashMap, - local_variables: HashMap, macro_expander: MacroExpander, } @@ -196,7 +197,7 @@ impl PILAnalyzer { ); } PilStatement::ConstantDefinition(_start, name, value) => { - self.handle_constant_definition(name, self.evaluate_expression(&value).unwrap()) + self.handle_constant_definition(name, self.evaluate_expression(value).unwrap()) } PilStatement::LetStatement(start, name, value) => { self.handle_generic_definition(start, name, value) @@ -255,7 +256,7 @@ impl PILAnalyzer { ); } _ => { - if let Some(constant) = self.evaluate_expression(&value) { + if let Ok(constant) = self.evaluate_expression(value.clone()) { // Value evaluates to a constant number => treat it as a constant self.handle_constant_definition(name.to_string(), constant); } else { @@ -288,25 +289,25 @@ impl PILAnalyzer { PilStatement::PlookupIdentity(start, key, haystack) => ( start, IdentityKind::Plookup, - self.process_selected_expression(key), - self.process_selected_expression(haystack), + ExpressionProcessor::new(self).process_selected_expression(key), + ExpressionProcessor::new(self).process_selected_expression(haystack), ), PilStatement::PermutationIdentity(start, left, right) => ( start, IdentityKind::Permutation, - self.process_selected_expression(left), - self.process_selected_expression(right), + ExpressionProcessor::new(self).process_selected_expression(left), + ExpressionProcessor::new(self).process_selected_expression(right), ), PilStatement::ConnectIdentity(start, left, right) => ( start, IdentityKind::Connect, SelectedExpressions { selector: None, - expressions: self.process_expressions(left), + expressions: ExpressionProcessor::new(self).process_expressions(left), }, SelectedExpressions { selector: None, - expressions: self.process_expressions(right), + expressions: ExpressionProcessor::new(self).process_expressions(right), }, ), // TODO at some point, these should all be caught by the type checker. @@ -335,7 +336,7 @@ impl PILAnalyzer { fn handle_namespace(&mut self, name: String, degree: ::ast::parsed::Expression) { // TODO: the polynomial degree should be handled without going through a field element. This requires having types in Expression - self.polynomial_degree = self.evaluate_expression(°ree).unwrap().to_degree(); + self.polynomial_degree = self.evaluate_expression(degree).unwrap().to_degree(); self.namespace = name; } @@ -366,7 +367,7 @@ impl PILAnalyzer { ) -> u64 { let have_array_size = array_size.is_some(); let length = array_size - .map(|l| self.evaluate_expression(&l).unwrap()) + .map(|l| self.evaluate_expression(l).unwrap()) .map(|l| l.to_degree()); if length.is_some() { assert!(value.is_none()); @@ -397,16 +398,21 @@ impl PILAnalyzer { FunctionDefinition::Mapping(params, expr) => { assert!(!have_array_size); assert!(symbol_kind == SymbolKind::Poly(PolynomialType::Constant)); - FunctionValueDefinition::Mapping(self.process_function(¶ms, expr)) + FunctionValueDefinition::Mapping( + ExpressionProcessor::new(self).process_function(¶ms, expr), + ) } FunctionDefinition::Query(params, expr) => { assert!(!have_array_size); assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed)); - FunctionValueDefinition::Query(self.process_function(¶ms, expr)) + FunctionValueDefinition::Query( + ExpressionProcessor::new(self).process_function(¶ms, expr), + ) } FunctionDefinition::Array(value) => { let size = value.solve(self.polynomial_degree); - let expression = self.process_array_expression(value, size); + let expression = + ExpressionProcessor::new(self).process_array_expression(value, size); assert_eq!( expression.iter().map(|e| e.size()).sum::(), self.polynomial_degree @@ -424,33 +430,6 @@ impl PILAnalyzer { id } - fn process_function( - &mut self, - params: &[String], - expression: ::ast::parsed::Expression, - ) -> Expression { - let previous_local_vars = std::mem::take(&mut self.local_variables); - - assert!(self.local_variables.is_empty()); - self.local_variables = params - .iter() - .enumerate() - .map(|(i, p)| (p.clone(), i as u64)) - .collect(); - // Re-add the outer local variables if we do not overwrite them - // and increase their index by the number of parameters. - // TODO re-evaluate if this mechanism makes sense as soon as we properly - // support nested functions and closures. - for (name, index) in &previous_local_vars { - self.local_variables - .entry(name.clone()) - .or_insert(index + params.len() as u64); - } - let processed_value = self.process_expression(expression); - self.local_variables = previous_local_vars; - processed_value - } - fn handle_public_declaration( &mut self, source: SourceRef, @@ -465,8 +444,9 @@ impl PILAnalyzer { id, source, name: name.to_string(), - polynomial: self.process_namespaced_polynomial_reference(poly), - index: self.evaluate_expression(&index).unwrap().to_degree(), + polynomial: ExpressionProcessor::new(self) + .process_namespaced_polynomial_reference(poly), + index: self.evaluate_expression(index).unwrap().to_degree(), }, ); self.source_order @@ -485,7 +465,7 @@ impl PILAnalyzer { id } - fn namespaced(&self, name: &str) -> String { + pub fn namespaced(&self, name: &str) -> String { self.namespaced_ref(&None, name) } @@ -499,7 +479,39 @@ impl PILAnalyzer { } } - fn process_selected_expression( + fn evaluate_expression(&self, expr: ::ast::parsed::Expression) -> Result { + Evaluator { + constants: &self.constants, + definitions: &self.definitions, + function_cache: &Default::default(), + variables: &[], + } + .evaluate(&self.process_expression(expr)) + } + + fn process_expression(&self, expr: ::ast::parsed::Expression) -> Expression { + ExpressionProcessor::new(self).process_expression(expr) + } +} + +/// The ExpressionProcessor turns parsed expressions into analyzed expressions. +/// Its main job is to resolve references: +/// It turns simple references into fully namespaced references and resolves local function variables. +/// It also evaluates expressions that are required to be compile-time constant. +struct ExpressionProcessor<'a, T> { + analyzer: &'a PILAnalyzer, + local_variables: HashMap, +} + +impl<'a, T: FieldElement> ExpressionProcessor<'a, T> { + fn new(analyzer: &'a PILAnalyzer) -> Self { + Self { + analyzer, + local_variables: Default::default(), + } + } + + pub fn process_selected_expression( &mut self, expr: ::ast::parsed::SelectedExpressions, ) -> SelectedExpressions { @@ -509,7 +521,7 @@ impl PILAnalyzer { } } - fn process_array_expression( + pub fn process_array_expression( &mut self, array_expression: ::ast::parsed::ArrayExpression, size: DegreeType, @@ -538,7 +550,7 @@ impl PILAnalyzer { } } - fn process_expressions( + pub fn process_expressions( &mut self, exprs: Vec<::ast::parsed::Expression>, ) -> Vec> { @@ -548,7 +560,7 @@ impl PILAnalyzer { .collect() } - fn process_expression(&mut self, expr: ::ast::parsed::Expression) -> Expression { + pub fn process_expression(&mut self, expr: ::ast::parsed::Expression) -> Expression { use ast::parsed::Expression as PExpression; match expr { PExpression::Constant(name) => Expression::Constant(name), @@ -586,7 +598,7 @@ impl PILAnalyzer { Expression::UnaryOperation(op, Box::new(self.process_expression(*value))) } PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall { - id: self.namespaced(&c.id), + id: self.analyzer.namespaced(&c.id), arguments: self.process_expressions(c.arguments), }), PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( @@ -607,16 +619,43 @@ impl PILAnalyzer { } } - fn process_namespaced_polynomial_reference( - &self, + fn process_function( + &mut self, + params: &[String], + expression: ::ast::parsed::Expression, + ) -> Expression { + let previous_local_vars = std::mem::take(&mut self.local_variables); + + assert!(self.local_variables.is_empty()); + self.local_variables = params + .iter() + .enumerate() + .map(|(i, p)| (p.clone(), i as u64)) + .collect(); + // Re-add the outer local variables if we do not overwrite them + // and increase their index by the number of parameters. + // TODO re-evaluate if this mechanism makes sense as soon as we properly + // support nested functions and closures. + for (name, index) in &previous_local_vars { + self.local_variables + .entry(name.clone()) + .or_insert(index + params.len() as u64); + } + let processed_value = self.process_expression(expression); + self.local_variables = previous_local_vars; + processed_value + } + + pub fn process_namespaced_polynomial_reference( + &mut self, poly: ::ast::parsed::NamespacedPolynomialReference, ) -> PolynomialReference { let index = poly .index() .as_ref() - .map(|i| self.evaluate_expression(i).unwrap()) + .map(|i| self.analyzer.evaluate_expression(*i.clone()).unwrap()) .map(|i| i.to_degree()); - let name = self.namespaced_ref(poly.namespace(), poly.name()); + let name = self.analyzer.namespaced_ref(poly.namespace(), poly.name()); PolynomialReference { name, poly_id: None, @@ -625,8 +664,8 @@ impl PILAnalyzer { } } - fn process_shifted_polynomial_reference( - &self, + pub fn process_shifted_polynomial_reference( + &mut self, poly: ::ast::parsed::ShiftedPolynomialReference, ) -> PolynomialReference { PolynomialReference { @@ -634,63 +673,6 @@ impl PILAnalyzer { ..self.process_namespaced_polynomial_reference(poly.into_namespaced()) } } - - fn evaluate_expression(&self, expr: &::ast::parsed::Expression) -> Option { - use ast::parsed::Expression::*; - match expr { - Constant(name) => Some( - *self - .constants - .get(name) - .unwrap_or_else(|| panic!("Constant {name} not found.")), - ), - Reference(name) => { - // TODO this whole mechanism should be replaced by a generic "reference" - // type plus operators. - if !name.shift() && name.namespace().is_none() { - // See if it might be a constant - self.constants.get(&name.name().to_owned()).cloned() - } else { - None - } - } - PublicReference(_) => None, - Number(n) => Some(*n), - String(_) => None, - Tuple(_) => None, - ArrayLiteral(_) => None, - LambdaExpression(_) => None, - BinaryOperation(left, op, right) => self.evaluate_binary_operation(left, *op, right), - UnaryOperation(op, value) => self.evaluate_unary_operation(*op, value), - FunctionCall(_) => None, - FreeInput(_) => panic!(), - MatchExpression(_, _) => None, - } - } - - fn evaluate_binary_operation( - &self, - left: &::ast::parsed::Expression, - op: BinaryOperator, - right: &::ast::parsed::Expression, - ) -> Option { - Some(ast::evaluate_binary_operation( - self.evaluate_expression(left)?, - op, - self.evaluate_expression(right)?, - )) - } - - fn evaluate_unary_operation( - &self, - op: UnaryOperator, - value: &::ast::parsed::Expression, - ) -> Option { - Some(ast::evaluate_unary_operation( - op, - self.evaluate_expression(value)?, - )) - } } pub fn inline_intermediate_polynomials(analyzed: &Analyzed) -> Vec> {